diff --git a/.buildkite/lm-eval-harness/configs/models-small-rocm.txt b/.buildkite/lm-eval-harness/configs/models-small-rocm.txt new file mode 100644 index 000000000000..a3bb95e19e24 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/models-small-rocm.txt @@ -0,0 +1,5 @@ +Qwen2.5-1.5B-Instruct.yaml +Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml +Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml +Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml +Qwen1.5-MoE-W4A16-compressed-tensors.yaml diff --git a/tests/compile/distributed/test_fusion_all_reduce.py b/tests/compile/distributed/test_fusion_all_reduce.py index d0a194c2b044..d2d90adae3a4 100644 --- a/tests/compile/distributed/test_fusion_all_reduce.py +++ b/tests/compile/distributed/test_fusion_all_reduce.py @@ -26,15 +26,14 @@ initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, - GroupShape, +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, ) from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables from vllm.utils.torch_utils import set_random_seed -from ...utils import has_module_attribute, multi_gpu_test +from ...utils import TestFP8Layer, has_module_attribute, multi_gpu_test from ..backend import TestBackend @@ -76,49 +75,40 @@ def ops_in_model_after(self): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] - self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - self.w = [ - torch.rand(hidden_size, hidden_size) - .to(dtype=current_platform.fp8_dtype()) - .t() - for _ in range(3) + self.fp8_linear_layers = [ + TestFP8Layer( + 
weight_shape=(hidden_size, hidden_size), + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, + ) + for i in range(3) ] - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, - ) - - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly z = torch.relu(hidden_states) x = resid = tensor_model_parallel_all_reduce(z) y = self.norm[0](x) - z2 = self.fp8_linear.apply( - y, self.w[0], self.wscale[0], input_scale=self.scale[0] - ) + z2 = self.fp8_linear_layers[0](y) x2 = tensor_model_parallel_all_reduce(z2) y2, resid = self.norm[1](x2, resid) - z3 = self.fp8_linear.apply( - y2, self.w[1], self.wscale[1], input_scale=self.scale[1] - ) + z3 = self.fp8_linear_layers[1](y2) x3 = tensor_model_parallel_all_reduce(z3) y3, resid = self.norm[2](x3, resid) # use resid here - z4 = self.fp8_linear.apply( - y3, self.w[2], self.wscale[2], input_scale=self.scale[2] - ) + z4 = self.fp8_linear_layers[2](y3) + x4 = tensor_model_parallel_all_reduce(z4) y4, resid = self.norm[3](x4, resid) # use resid here return y4 @@ -130,7 +120,7 @@ def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, torch.ops._C.static_scaled_fp8_quant.default - if self.fp8_linear.quant_fp8.enabled() + if self.fp8_linear_layers[0].is_quant_fp8_enabled() else torch.ops.aten.reciprocal.default, ] diff --git a/tests/compile/distributed/test_sequence_parallelism.py b/tests/compile/distributed/test_sequence_parallelism.py index 35916ba99652..d8a1a4288dc6 100644 --- a/tests/compile/distributed/test_sequence_parallelism.py +++ b/tests/compile/distributed/test_sequence_parallelism.py @@ -27,13 +27,14 @@ initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape -from vllm.model_executor.layers.quantization.utils.w8a8_utils 
import Fp8LinearOp +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables from vllm.utils.torch_utils import set_random_seed -from ...utils import multi_gpu_test +from ...utils import TestFP8Layer, multi_gpu_test from ..backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -94,50 +95,40 @@ def ops_in_model(self): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size=16, eps=1e-6): super().__init__() self.vllm_config = get_current_vllm_config() self.hidden_size = hidden_size self.eps = eps self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] - self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - self.w = [ - torch.rand(hidden_size, hidden_size) - .to(dtype=current_platform.fp8_dtype()) - .t() - for _ in range(3) + self.fp8_linear_layers = [ + TestFP8Layer( + weight_shape=(hidden_size, hidden_size), + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, + ) + for i in range(3) ] - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, - ) - - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly z = torch.relu(hidden_states) x = resid = tensor_model_parallel_all_reduce(z) y = self.norm[0](x) - z2 = self.fp8_linear.apply( - y, self.w[0], self.wscale[0], input_scale=self.scale[0] - ) + z2 = self.fp8_linear_layers[0](y) x2 = tensor_model_parallel_all_reduce(z2) y2, resid = self.norm[1](x2, resid) - z3 = self.fp8_linear.apply( - y2, self.w[1], self.wscale[1], input_scale=self.scale[1] - ) + z3 = self.fp8_linear_layers[1](y2) x3 = tensor_model_parallel_all_reduce(z3) y3, resid = self.norm[2](x3, resid) # use resid here - z4 = 
self.fp8_linear.apply( - y3, self.w[2], self.wscale[2], input_scale=self.scale[2] - ) + z4 = self.fp8_linear_layers[2](y3) x4 = tensor_model_parallel_all_reduce(z4) y4, resid = self.norm[3](x4, resid) # use resid here return y4 @@ -160,7 +151,7 @@ def ops_in_model(self): return [ torch.ops._C.fused_add_rms_norm.default, ] - elif self.fp8_linear.quant_fp8.enabled(): + elif any(layer.is_quant_fp8_enabled() for layer in self.fp8_linear_layers): return [ torch.ops._C.static_scaled_fp8_quant.default, ] diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index ad5ead36e231..9791493fd164 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -20,11 +20,13 @@ ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform +from ..utils import TestFP8Layer from .backend import TestBackend TEST_FP8 = current_platform.supports_fp8() @@ -32,24 +34,22 @@ class TestSiluMul(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size: int = 128): super().__init__() self.silu_and_mul = SiluAndMul() - self.wscale = torch.rand(1, dtype=torch.float32) - self.scale = torch.rand(1, dtype=torch.float32) - if TEST_FP8: - self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, + self.fp8_linear = TestFP8Layer( + weight_shape=(hidden_size, hidden_size), + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, ) def 
forward(self, x): y = self.silu_and_mul(x) if TEST_FP8: - x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) - return x2 + return self.fp8_linear(y) else: return y @@ -67,6 +67,8 @@ def ops_not_in_model(self): class TestFusedAddRMSNorm(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size @@ -81,11 +83,11 @@ def __init__(self, hidden_size=16, intermediate_size=32): torch.nn.init.normal_(self.gate_proj, std=0.02) if TEST_FP8: - self.fp8_linear = Fp8LinearOp(act_quant_static=True) - - self.scale = torch.rand(1, dtype=torch.float32) - self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() - self.wscale = torch.rand(1, dtype=torch.float32) + self.fp8_linear = TestFP8Layer( + weight_shape=(hidden_size, intermediate_size), + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, + ) def forward(self, hidden_states, residual): # Reshape input @@ -100,12 +102,7 @@ def forward(self, hidden_states, residual): if TEST_FP8: # scaled_mm with static input quantization - fp8_linear_result = self.fp8_linear.apply( - norm_output, - self.w, - self.wscale, - input_scale=self.scale.to(norm_output.device), - ) + fp8_linear_result = self.fp8_linear(norm_output) return fp8_linear_result, residual_output diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 7755e9f9b738..26afb0203369 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,6 +5,7 @@ import pytest import torch +import vllm.config import vllm.plugins from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass @@ -20,8 +21,22 @@ VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp, +from 
vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import ( + FlashInferFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( + ChannelWiseTorchFP8ScaledMMLinearKernel, + PerTensorTorchFP8ScaledMMLinearKernel, + RowWiseTorchFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( + ROCmFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -29,15 +44,14 @@ ScaleDesc, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_block_fp8_supported, - cutlass_fp8_supported, - maybe_create_device_identity, ) from vllm.platforms import current_platform -from vllm.utils.deep_gemm import is_deep_gemm_supported +from vllm.utils.deep_gemm import ( + is_deep_gemm_supported, +) -from ..utils import override_cutlass_fp8_supported +from ..utils import TestBlockFP8Layer, TestFP8Layer from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -45,157 +59,195 @@ RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default +# Kernel and group_shape combinations: (kernel, group_shape) +# CUDA kernels +CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [ + # FlashInferFP8ScaledMMLinearKernel supports per-tensor only + (FlashInferFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR), + # CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token + (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN), + (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR), + # PerTensorTorchFP8ScaledMMLinearKernel only supports per-tensor + (PerTensorTorchFP8ScaledMMLinearKernel, 
GroupShape.PER_TENSOR), + # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token + (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN), + # Blockwise group shapes (no kernel abstraction) + (None, GroupShape(1, 128)), + (None, GroupShape(1, 64)), +] + +# ROCm kernels +ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [ + # ROCmFP8ScaledMMLinearKernel supports per-tensor only + (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR), + # RowWiseTorchFP8ScaledMMLinearKernel only supports per-token + (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN), + # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token + (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN), + # Blockwise group shapes (no kernel abstraction) + (None, GroupShape(1, 128)), + (None, GroupShape(1, 64)), +] + +KERNEL_GROUPSHAPE_COMBINATIONS = ( + CUDA_KERNEL_GROUPSHAPE_COMBINATIONS + if current_platform.is_cuda() + else ROCM_KERNEL_GROUPSHAPE_COMBINATIONS +) + +# For Aiter tests we toggle use_aiter_quant_op +AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [ + # Per-tensor with ROCmFP8ScaledMMLinearKernel + (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR, False), + # Per-token with RowWiseTorchFP8ScaledMMLinearKernel + (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True), + (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False), + # Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel + (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True), + (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False), + # Blockwise (no kernel abstraction) + (None, GroupShape(1, 128), True), +] + class TestModel(torch.nn.Module): def __init__( self, hidden_size: int, eps: float, + force_kernel: FP8ScaledMMLinearKernel | None, group_shape: GroupShape, - use_aiter: bool = False, - cuda_force_torch: bool = False, - use_aiter_quant_op: bool = True, + use_aiter_fusion: bool = False, + use_aiter_quant: bool = False, *args, **kwargs, ): 
super().__init__(*args, **kwargs) - self.use_aiter = use_aiter - self.use_aiter_quant_op = use_aiter_quant_op - self.cuda_force_torch = cuda_force_torch + self.fp8_linear_layers: list[torch.nn.Module] self.group_shape = group_shape - self.enable_quant_fp8_custom_op = None # Will be set later if applicable - + self.use_aiter_quant_op = use_aiter_quant + self.use_aiter_fusion = use_aiter_fusion self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] + self.enable_rms_norm_custom_op = self.norm[0].enabled() - # Setup quantization scale descriptor - static = group_shape == GroupShape.PER_TENSOR and not use_aiter - quant_scale = ScaleDesc(torch.float32, static, group_shape) - self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) + # Determine if blockwise based on group_shape + is_blockwise = group_shape.is_per_group() - # Setup scales - if static: - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - else: - self.scale = [None for _ in range(3)] + if is_blockwise: + act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape) + self.activation_quant_key = QuantKey( + dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True + ) + self.fp8_linear_layers = [ + TestBlockFP8Layer( + weight_shape=(hidden_size, hidden_size), + group_shape=group_shape, + cutlass_block_fp8_supported=cutlass_block_fp8_supported(), + use_aiter_and_is_supported=use_aiter_quant, + transpose_weights=use_aiter_fusion, + ) + for _ in range(3) + ] - # Setup weights - self.w = [ - torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3) - ] - if not group_shape.is_per_group() or use_aiter: - self.w = [self.w[0].t() for _ in range(3)] - - # Setup weight scales - if group_shape.is_per_group(): - scale_size = ( - (hidden_size + 128 - 1) // 128 - if use_aiter - else hidden_size // group_shape[1] + self.enable_quant_fp8_custom_op = ( + False + if use_aiter_quant + else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled() ) - 
wscale_shape: tuple[int, ...] = (scale_size, scale_size) + else: - wscale_shape = (1,) - self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)] - - # Setup FP8 linear operation - is_per_group = group_shape.is_per_group() - if is_per_group and use_aiter: - self.fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(128, 128), - act_quant_group_shape=group_shape, - use_aiter_and_is_supported=use_aiter_quant_op, - ) - # AITER blockwise doesn't use enable_quant_fp8_custom_op - elif is_per_group: - self.fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(group_shape[1], group_shape[1]), - act_quant_group_shape=group_shape, - cutlass_block_fp8_supported=cutlass_block_fp8_supported(), - use_aiter_and_is_supported=False, + is_static = group_shape == GroupShape.PER_TENSOR + act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape) + w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape) + self.activation_quant_key = QuantKey( + dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True ) - self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled() - elif use_aiter: - self.fp8_linear = Fp8LinearOp( - act_quant_static=False, - act_quant_group_shape=group_shape, + self.weight_quant_key = QuantKey( + dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True ) - self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() - else: - with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = Fp8LinearOp( - act_quant_static=static, - act_quant_group_shape=group_shape, + self.fp8_linear_layers = [ + TestFP8Layer( + weight_shape=(hidden_size, hidden_size), + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + force_kernel=force_kernel, ) - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + for _ in range(3) + ] - self.enable_rms_norm_custom_op = 
self.norm[0].enabled() + # Enable aiter quantization if requested + for layer in self.fp8_linear_layers: + layer.kernel.quant_fp8.use_aiter = use_aiter_quant + + self.enable_quant_fp8_custom_op = self.fp8_linear_layers[ + 0 + ].is_quant_fp8_enabled() def forward(self, x): # avoid having graph input be an arg to a pattern directly x = resid = torch.relu(x) y = self.norm[0](x) - x2 = self.fp8_linear.apply( - y, self.w[0], self.wscale[0], input_scale=self.scale[0] - ) + x2 = self.fp8_linear_layers[0](y) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear.apply( - y2, self.w[1], self.wscale[1], input_scale=self.scale[1] - ) + x3 = self.fp8_linear_layers[1](y2) y3, resid = self.norm[2](x3, resid) # use resid here - x4 = self.fp8_linear.apply( - y3, self.w[2], self.wscale[2], input_scale=self.scale[2] - ) + x4 = self.fp8_linear_layers[2](y3) y4, resid = self.norm[3](x4, resid) # use resid here return y4 def ops_in_model_before(self): - if ( - self.use_aiter - and self.group_shape.is_per_group() - and current_platform.is_fp8_fnuz() - ): - return [rocm_aiter_ops.get_group_quant_op()] - if self.use_aiter and self.group_shape.is_per_group(): - return [torch.ops.vllm.triton_per_token_group_quant_fp8.default] - if self.use_aiter and self.use_aiter_quant_op: - return [rocm_aiter_ops.get_per_token_quant_op()] - if self.use_aiter: - return [QUANT_OPS[self.quant_key]] - if self.enable_quant_fp8_custom_op: - return [QUANT_OPS[self.quant_key]] - return [torch.ops.aten.reciprocal] + if self.group_shape.is_per_group(): + # Blockwise path + if self.use_aiter_fusion and self.use_aiter_quant_op: + return [rocm_aiter_ops.get_group_quant_op()] + if self.use_aiter_fusion: + return [torch.ops.vllm.triton_per_token_group_quant_fp8.default] + else: + if self.use_aiter_quant_op: + return [rocm_aiter_ops.get_per_token_quant_op()] + + # Common path + return ( + [QUANT_OPS[self.activation_quant_key]] + if self.enable_quant_fp8_custom_op + 
else [torch.ops.aten.reciprocal] + ) def ops_in_model_after(self): - if self.use_aiter and self.group_shape.is_per_group(): - from vllm.compilation.rocm_aiter_fusion import ( - AiterFusedAddRMSFp8GroupQuantPattern, - AiterRMSFp8GroupQuantPattern, - ) + if self.use_aiter_fusion: + if self.group_shape.is_per_group(): + # Blockwise aiter fusion + from vllm.compilation.rocm_aiter_fusion import ( + AiterFusedAddRMSFp8GroupQuantPattern, + AiterRMSFp8GroupQuantPattern, + ) - return [ - AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP, - AiterRMSFp8GroupQuantPattern.FUSED_OP, - ] - if self.use_aiter: - from vllm.compilation.rocm_aiter_fusion import ( - AiterFusedAddRMSNormDynamicQuantPattern, - AiterRMSNormDynamicQuantPattern, - ) + return [ + AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP, + AiterRMSFp8GroupQuantPattern.FUSED_OP, + ] + else: + # Per-token aiter fusion + from vllm.compilation.rocm_aiter_fusion import ( + AiterFusedAddRMSNormDynamicQuantPattern, + AiterRMSNormDynamicQuantPattern, + ) - return [ - AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP, - AiterRMSNormDynamicQuantPattern.FUSED_OP, - ] + return [ + AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP, + AiterRMSNormDynamicQuantPattern.FUSED_OP, + ] + + # Regular fusion return [ - FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)], - FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)], + FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)], + FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)], ] def ops_in_model_before_partial(self): @@ -206,14 +258,6 @@ def ops_in_model_before_partial(self): ) -GROUP_SHAPES = [ - GroupShape.PER_TOKEN, - GroupShape.PER_TENSOR, - GroupShape(1, 128), - GroupShape(1, 64), -] - - def _run_fusion_test( model, fusion_pass, @@ -259,14 +303,9 @@ def _run_fusion_test( @pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) -@pytest.mark.parametrize("group_shape", GROUP_SHAPES) 
+@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS) @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) -# cuda_force_torch used to test torch code path on platforms that -# cutlass_fp8_supported() == True. -@pytest.mark.parametrize( - "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] -) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" ) @@ -275,11 +314,12 @@ def test_fusion_rmsnorm_quant( hidden_size, num_tokens, eps, - group_shape, + kernel_groupshape, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, - cuda_force_torch, ): + force_kernel, group_shape = kernel_groupshape + if not enable_quant_fp8_custom_op and group_shape.is_per_group(): pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization") @@ -310,15 +350,16 @@ def test_fusion_rmsnorm_quant( torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) - maybe_create_device_identity() fusion_pass = RMSNormQuantFusionPass(vllm_config) + model = TestModel( hidden_size=hidden_size, eps=eps, + force_kernel=force_kernel, group_shape=group_shape, - use_aiter=False, - cuda_force_torch=cuda_force_torch, + use_aiter_fusion=False, + use_aiter_quant=False, ) backend, _ = _run_fusion_test( @@ -339,19 +380,12 @@ def test_fusion_rmsnorm_quant( assert n_add_nodes(backend.graph_post_pass) == 2 -GROUP_SHAPE_QUANT_OPS_MATCHS = [ - (GroupShape.PER_TOKEN, True), - (GroupShape.PER_TOKEN, False), - (GroupShape(1, 128), True), -] - - @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize( - "group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS + "kernel_groupshape_quant", AITER_KERNEL_GROUPSHAPE_COMBINATIONS ) @pytest.mark.skipif( (not 
current_platform.is_rocm() or not IS_AITER_FOUND), @@ -362,10 +396,10 @@ def test_aiter_fusion_rmsnorm_quant( hidden_size: int, num_tokens: int, eps: float, - group_shape: GroupShape, - use_aiter_quant_op: bool, + kernel_groupshape_quant: tuple, monkeypatch: pytest.MonkeyPatch, ): + force_kernel, group_shape, use_aiter_quant_op = kernel_groupshape_quant vllm_config = VllmConfig( model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( @@ -379,20 +413,22 @@ def test_aiter_fusion_rmsnorm_quant( from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass m.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) - maybe_create_device_identity() fusion_pass = RocmAiterRMSNormFusionPass(vllm_config) + model = TestModel( hidden_size=hidden_size, eps=eps, + force_kernel=force_kernel, group_shape=group_shape, - use_aiter=True, - use_aiter_quant_op=use_aiter_quant_op, + use_aiter_fusion=True, # Always use aiter fusion ops in aiter test + use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization ) _run_fusion_test( diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index ef6978a7ffd2..1f1dafcca55b 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -45,7 +45,6 @@ kFp8StaticTensorSym, kNvfp4Quant, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -53,6 +52,8 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.kv_cache_interface import AttentionSpec +from ..utils import TestFP8Layer + FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 @@ -185,32 +186,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): def 
__init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.quant_key.scale.static, - act_quant_group_shape=self.quant_key.scale.group_shape, - ) - hidden_size = self.num_qo_heads * self.head_size - self.w = kwargs.get( - "w", - { - "weight": torch.randn(hidden_size, hidden_size) - .to(dtype=FP8_DTYPE, device=self.device) - .t(), - "wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device), - "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device), - }, + self.fp8_linear = TestFP8Layer( + weight_shape=(hidden_size, hidden_size), + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, + device=self.device, ) + w = kwargs.get("w") + if w is not None: + self.fp8_linear.weight = w["weight"] + self.fp8_linear.weight_scale = w["wscale"] + self.fp8_linear.input_scale = w["scale"] + + self.w = { + "weight": self.fp8_linear.weight, + "wscale": self.fp8_linear.weight_scale, + "scale": self.fp8_linear.input_scale, + } + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) - return self.fp8_linear.apply( - input=attn_output, - weight=self.w["weight"], - weight_scale=self.w["wscale"], - input_scale=self.w["scale"], - ) + return self.fp8_linear(attn_output) class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index eb0dee8d4e39..b71487275144 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -25,19 +25,30 @@ set_current_vllm_config, ) from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import 
( + FlashInferFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( + PerTensorTorchFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( + ROCmFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8StaticTensorSym, kNvfp4Quant, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, - maybe_create_device_identity, -) from vllm.platforms import current_platform -from ..utils import override_cutlass_fp8_supported +from ..utils import TestFP8Layer from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -49,25 +60,27 @@ def is_nvfp4_supported(): class TestSiluMulFp8QuantModel(torch.nn.Module): - def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): + quant_key = kFp8StaticTensorSym + + def __init__( + self, hidden_size: int, force_kernel: FP8ScaledMMLinearKernel, **kwargs + ): super().__init__() self.silu_and_mul = SiluAndMul() - self.wscale = torch.rand(1, dtype=torch.float32) - self.scale = torch.rand(1, dtype=torch.float32) - self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + self.fp8_linear = TestFP8Layer( + weight_shape=(hidden_size, hidden_size), + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, + force_kernel=force_kernel, + ) - with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, - ) self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + self.enable_quant_fp8_custom_op = 
self.fp8_linear.is_quant_fp8_enabled() def forward(self, x): y = self.silu_and_mul(x) - x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) + x2 = self.fp8_linear(y) return x2 def ops_in_model_before(self): @@ -161,20 +174,27 @@ def ops_in_model_after(self): return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant] +ROCM_KERNELS = [ROCmFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel] +CUDA_KERNELS = [ + FlashInferFP8ScaledMMLinearKernel, + CutlassFP8ScaledMMLinearKernel, + PerTensorTorchFP8ScaledMMLinearKernel, +] +TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS + + @pytest.mark.parametrize("num_tokens", [32, 64]) @pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False]) @pytest.mark.parametrize( - "model_class, enable_quant_fp8_custom_op, cuda_force_torch", - list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False])) + "model_class, enable_quant_fp8_custom_op, force_kernel", + list(itertools.product([TestSiluMulFp8QuantModel], [True, False], TEST_KERNELS)) + [ - (TestSiluMulNvfp4QuantModel, False, False), - (TestSiluMulGroupFp8QuantModel, False, False), + (TestSiluMulNvfp4QuantModel, False, None), + (TestSiluMulGroupFp8QuantModel, False, None), ], ) -# cuda_force_torch used to test torch code path on platforms that -# cutlass_fp8_supported() == True. 
@pytest.mark.skipif( envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm" ) @@ -189,7 +209,7 @@ def test_fusion_silu_and_mul_quant( ], enable_silu_mul_custom_op: bool, enable_quant_fp8_custom_op: bool, - cuda_force_torch: bool, + force_kernel: FP8ScaledMMLinearKernel | None, ): if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported(): pytest.skip("NVFP4 is not supported on this GPU.") @@ -198,7 +218,6 @@ def test_fusion_silu_and_mul_quant( torch.set_default_device("cuda") torch.set_default_dtype(dtype) - maybe_create_device_identity() x = torch.rand(num_tokens, hidden_size * 2) @@ -227,9 +246,7 @@ def test_fusion_silu_and_mul_quant( passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)] backend = TestBackend(*passes) - model = model_class( - hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x - ) + model = model_class(hidden_size=hidden_size, force_kernel=force_kernel, x=x) # First dimension dynamic torch._dynamo.mark_dynamic(x, 0) diff --git a/tests/kernels/quantization/test_scaled_mm_kernel_selection.py b/tests/kernels/quantization/test_scaled_mm_kernel_selection.py index 2ed55931c816..1de8c444cf76 100644 --- a/tests/kernels/quantization/test_scaled_mm_kernel_selection.py +++ b/tests/kernels/quantization/test_scaled_mm_kernel_selection.py @@ -11,13 +11,13 @@ import pytest from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - ScaledMMLinearLayerConfig, + Int8ScaledMMLinearLayerConfig, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( - AiterScaledMMLinearKernel, + AiterInt8ScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( - CPUScaledMMLinearKernel, + CPUInt8ScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 ScaledMMLinearKernel, @@ -33,36 +33,38 @@ def test_is_supported_is_abstract(): def 
test_cpu_kernel_implements_is_supported(): - """Test that CPUScaledMMLinearKernel implements is_supported() method.""" - assert hasattr(CPUScaledMMLinearKernel, "is_supported"), ( - "CPUScaledMMLinearKernel missing is_supported() method" + """Test that CPUInt8ScaledMMLinearKernel implements is_supported() method.""" + assert hasattr(CPUInt8ScaledMMLinearKernel, "is_supported"), ( + "CPUInt8ScaledMMLinearKernel missing is_supported() method" ) # Verify it's a classmethod by checking if it can be called with the class # and by checking the method type - assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction( - CPUScaledMMLinearKernel.is_supported - ), "CPUScaledMMLinearKernel.is_supported() should be a classmethod" + assert inspect.ismethod( + CPUInt8ScaledMMLinearKernel.is_supported + ) or inspect.isfunction(CPUInt8ScaledMMLinearKernel.is_supported), ( + "CPUInt8ScaledMMLinearKernel.is_supported() should be a classmethod" + ) # Verify it can be called as a classmethod - result, reason = CPUScaledMMLinearKernel.is_supported() + result, reason = CPUInt8ScaledMMLinearKernel.is_supported() assert isinstance(result, bool), "is_supported() should return a bool" assert reason is None or isinstance(reason, str), "reason should be str or None" def test_aiter_kernel_implements_is_supported(): - """Test that AiterScaledMMLinearKernel implements is_supported() method.""" - assert hasattr(AiterScaledMMLinearKernel, "is_supported"), ( - "AiterScaledMMLinearKernel missing is_supported() method" + """Test that AiterInt8ScaledMMLinearKernel implements is_supported() method.""" + assert hasattr(AiterInt8ScaledMMLinearKernel, "is_supported"), ( + "AiterInt8ScaledMMLinearKernel missing is_supported() method" ) # Verify it's a classmethod by checking if it can be called with the class # and by checking the method type assert inspect.ismethod( - AiterScaledMMLinearKernel.is_supported - ) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), ( - 
"AiterScaledMMLinearKernel.is_supported() should be a classmethod" + AiterInt8ScaledMMLinearKernel.is_supported + ) or inspect.isfunction(AiterInt8ScaledMMLinearKernel.is_supported), ( + "AiterInt8ScaledMMLinearKernel.is_supported() should be a classmethod" ) # Verify it can be called as a classmethod # (will return False on CPU, which is expected) - result, reason = AiterScaledMMLinearKernel.is_supported() + result, reason = AiterInt8ScaledMMLinearKernel.is_supported() assert isinstance(result, bool), "is_supported() should return a bool" assert reason is None or isinstance(reason, str), "reason should be str or None" # On CPU, it should return False with a reason about requiring ROCm @@ -70,14 +72,14 @@ def test_aiter_kernel_implements_is_supported(): def test_cpu_kernel_accepts_all_configs(): - """Test that CPUScaledMMLinearKernel accepts all config combinations.""" + """Test that CPUInt8ScaledMMLinearKernel accepts all config combinations.""" configs = [ - ScaledMMLinearLayerConfig( + Int8ScaledMMLinearLayerConfig( is_channelwise=False, is_static_input_scheme=True, input_symmetric=True, ), - ScaledMMLinearLayerConfig( + Int8ScaledMMLinearLayerConfig( is_channelwise=True, is_static_input_scheme=False, input_symmetric=False, @@ -85,7 +87,7 @@ def test_cpu_kernel_accepts_all_configs(): ] for config in configs: - can_impl, reason = CPUScaledMMLinearKernel.can_implement(config) + can_impl, reason = CPUInt8ScaledMMLinearKernel.can_implement(config) assert can_impl, ( - f"CPUScaledMMLinearKernel should accept config {config}: {reason}" + f"CPUInt8ScaledMMLinearKernel should accept config {config}: {reason}" ) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 7fb38708c353..085c2a703e4c 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -41,7 +41,7 @@ "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", ] -# TritonScaledMMLinearKernel 
only supports symmetric quantization. +# TritonInt8ScaledMMLinearKernel only supports symmetric quantization. ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [ "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", diff --git a/tests/utils.py b/tests/utils.py index c186f5c47a59..b3d21ac55b75 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -42,6 +42,17 @@ ) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.cli.serve import ServeSubcommand +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + QuantKey, +) from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform from vllm.tokenizers import get_tokenizer @@ -50,6 +61,8 @@ from vllm.utils.network_utils import get_open_port from vllm.utils.torch_utils import cuda_device_count_stateless +FP8_DTYPE = current_platform.fp8_dtype() + if current_platform.is_rocm(): from amdsmi import ( amdsmi_get_gpu_vram_usage, @@ -1332,3 +1345,117 @@ def flat_product(*iterables: Iterable[Any]): for element in itertools.product(*iterables): normalized = (e if isinstance(e, tuple) else (e,) for e in element) yield tuple(itertools.chain(*normalized)) + + +class TestFP8Layer(torch.nn.Module): + """ + Test helper for FP8 linear operations. Creates random weights and scales + based on quantization configuration. + + Args: + weight_shape: Shape of the weight tensor (out_features, in_features). + activation_quant_key: Activation quantization configuration. + weight_quant_key: Weight quantization configuration. + out_dtype: Output dtype. 
Defaults to current default dtype. + force_kernel: Optional kernel to force use of specific implementation. + """ + + def __init__( + self, + weight_shape: tuple[int, int], + activation_quant_key: QuantKey, + weight_quant_key: QuantKey, + out_dtype: torch.dtype | None = None, + device: torch.device | None = None, + force_kernel: FP8ScaledMMLinearKernel | None = None, + ): + super().__init__() + per_tensor_weights = weight_quant_key.scale.group_shape.is_per_tensor() + is_static_activation_scale = activation_quant_key.scale.static + weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1) + + self.weight_scale = torch.rand( + weight_scale_shape, dtype=torch.float32, device=device + ) + self.input_scale = ( + torch.rand(1, dtype=torch.float32, device=device) + if is_static_activation_scale + else None + ) + self.weight = torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t() + self.input_scale_ub = None + + out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype + + self.kernel = init_fp8_linear_kernel( + activation_quant_key=activation_quant_key, + weight_quant_key=weight_quant_key, + out_dtype=out_dtype, + force_kernel=force_kernel, + ) + + def is_quant_fp8_enabled(self) -> bool: + return self.kernel.quant_fp8.enabled() + + def forward( + self, y: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + return self.kernel.apply_weights(self, y, bias) + + +# TODO: Drop TestBlockFP8Layer in favour of a unified TestFP8Layer +# after refactoring W8A8BlockFp8LinearOp. +# https://github.com/vllm-project/vllm/issues/31818 +class TestBlockFP8Layer: + """ + Test helper for blockwise FP8 linear operations. Creates random weights + and scales for W8A8BlockFp8LinearOp. + + This is a workaround until W8A8BlockFp8LinearOp implements the kernel + abstraction (ScaledMMLinearKernel) for blockwise quantization. + + Args: + weight_shape: Shape of the weight tensor (out_features, in_features). 
+ group_shape: Blockwise quantization group shape. + cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available. + use_aiter_and_is_supported: Whether to use aiter quantization ops. + transpose_weights: Whether to transpose weights after creation. + """ + + def __init__( + self, + weight_shape: tuple[int, int], + group_shape: GroupShape, + cutlass_block_fp8_supported: bool = False, + use_aiter_and_is_supported: bool = False, + transpose_weights: bool = False, + ): + weight_scale_shape = weight_shape[0] // group_shape[1] + self.weight_scale = torch.rand( + (weight_scale_shape, weight_scale_shape), dtype=torch.float32 + ) + self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE) + self.input_scale = None + if transpose_weights: + self.weight = self.weight.t() + + self.linear_op = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(group_shape[1], group_shape[1]), + act_quant_group_shape=group_shape, + cutlass_block_fp8_supported=cutlass_block_fp8_supported, + use_aiter_and_is_supported=use_aiter_and_is_supported, + ) + + def __call__( + self, y: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + return self.linear_op.apply( + input=y, + weight=self.weight, + weight_scale=self.weight_scale, + input_scale=self.input_scale, + bias=bias, + ) + + def is_quant_fp8_enabled(self) -> bool: + return self.linear_op.input_quant_op.enabled() diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 2a247c6d53f2..2e0c4a69c82f 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -372,7 +372,7 @@ def _rocm_aiter_gemm_a8w8_impl( # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects # a to be [M, K] # b to be [N, K] - # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format + # CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 0c303b335935..1120202f29fd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -8,9 +8,13 @@ from torch.nn import Parameter from vllm._aiter_ops import rocm_aiter_ops +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, create_fp8_input_scale, @@ -22,11 +26,14 @@ process_fp8_weight_tensor_strategy, validate_fp8_block_shape, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kFp8StaticTokenSym, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_block_fp8_supported, - maybe_create_device_identity, ) from vllm.model_executor.parameter import ( BlockQuantScaleParameter, @@ -42,6 +49,18 @@ QuantizationStrategy.TENSOR: PerTensorScaleParameter, } +STATIC_QUANT = True +DYNAMIC_QUANT = False +activation_quant_key_mapping = { + STATIC_QUANT: kFp8StaticTensorSym, + DYNAMIC_QUANT: kFp8DynamicTokenSym, +} +weight_quant_key_mapping = { + QuantizationStrategy.CHANNEL: kFp8StaticTokenSym, + QuantizationStrategy.TENSOR: kFp8StaticTensorSym, +} +logger = init_logger(__name__) + class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): @@ -49,22 +68,13 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) 
self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme - self.weight_block_size = self.weight_quant.block_structure - if self.weight_block_size is not None: - self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) - else: - self.act_q_group_shape = ( - GroupShape.PER_TENSOR - if is_static_input_scheme - else GroupShape.PER_TOKEN - ) - - self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() - self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled() if self.weight_block_size is not None: + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() + self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled() assert not self.is_static_input_scheme + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( weight_group_shape=GroupShape(*self.weight_block_size), act_quant_group_shape=self.act_q_group_shape, @@ -72,9 +82,13 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_q_group_shape, + activation_quant_key = activation_quant_key_mapping[is_static_input_scheme] + weight_quant_key = weight_quant_key_mapping[self.strategy] + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=activation_quant_key, + weight_quant_key=weight_quant_key, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, ) @classmethod @@ -93,8 +107,6 @@ def create_weights( weight_loader: Callable, **kwargs, ): - maybe_create_device_identity() - output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes layer.weight_block_size = None @@ -143,7 +155,6 @@ def process_weights_after_loading(self, layer) -> None: getattr(layer, 
"input_scale", None), ) weight = weight.t() - elif self.strategy == QuantizationStrategy.CHANNEL: weight, weight_scale, input_scale = process_fp8_weight_channel_strategy( layer.weight, layer.weight_scale, getattr(layer, "input_scale", None) @@ -174,7 +185,6 @@ def process_weights_after_loading(self, layer) -> None: layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) else: layer.input_scale = None - if self.strategy == QuantizationStrategy.BLOCK: maybe_post_process_fp8_weight_block(layer) @@ -193,11 +203,4 @@ def apply_weights( bias=bias, ) - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 6fd0a6a1c822..652feb196457 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -11,8 +11,7 @@ CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - ScaledMMLinearLayerConfig, - choose_scaled_mm_linear_kernel, + init_int8_linear_kernel, ) from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -25,8 +24,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): - _kernel_backends_being_used: set[str] = set() - def __init__( self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool ): @@ -50,18 +47,13 @@ def create_weights( ): layer.logical_widths = output_partition_sizes - scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + self.kernel = init_int8_linear_kernel( is_channelwise=(self.strategy == 
QuantizationStrategy.CHANNEL), is_static_input_scheme=self.is_static_input_scheme, input_symmetric=self.input_symmetric, + module_name=self.__class__.__name__, ) - kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) - - if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__) - self._kernel_backends_being_used.add(kernel_type.__name__) - # WEIGHT weight = ModelWeightParameter( data=torch.empty( @@ -90,12 +82,12 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE + input_zero_point = None + input_scale = None if self.is_static_input_scheme: input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader ) - layer.register_parameter("input_scale", input_scale) - if not self.input_symmetric: # Note: compressed-tensors stores the zp using the same dtype # as the weights @@ -103,16 +95,11 @@ def create_weights( input_zero_point = BasevLLMParameter( data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader ) - layer.register_parameter("input_zero_point", input_zero_point) - - self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, - w_q_param_name="weight", - w_s_param_name="weight_scale", - i_s_param_name="input_scale", - i_zp_param_name="input_zero_point", - azp_adj_param_name="azp_adj", - ) + + layer.register_parameter("input_zero_point", input_zero_point) + layer.register_parameter("input_scale", input_scale) + if not hasattr(layer, "azp_adj"): + layer.register_parameter("azp_adj", None) # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. Handle repacking here. 
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 6ba18e59e4d5..45d2e4e33819 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -18,17 +18,19 @@ QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, is_layer_skipped, + kFp8DynamicTokenSym, + kFp8StaticTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, - maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, ) from vllm.model_executor.parameter import ( @@ -91,10 +93,13 @@ def get_quant_method( class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp( - act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN - ) self.out_dtype = torch.get_default_dtype() + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8StaticTokenSym, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, + ) def create_weights( self, @@ -106,7 +111,6 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - maybe_create_device_identity() weight_loader = extra_weight_attrs.get("weight_loader") del input_size, output_size output_size_per_partition = sum(output_partition_sizes) @@ -184,12 +188,4 @@ def apply( bias=bias, ) - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=None, - input_scale_ub=layer.input_scale_ub, - bias=bias, - ) + 
return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 14ed28630680..6636c05cd197 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -48,6 +48,9 @@ QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, @@ -76,12 +79,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, is_layer_skipped, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_block_fp8_supported, cutlass_fp8_supported, - maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, ) from vllm.model_executor.parameter import ( @@ -328,28 +332,30 @@ def __init__(self, quant_config: Fp8Config): self.weight_block_size = self.quant_config.weight_block_size self.block_quant = self.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" - if self.weight_block_size: - self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) - else: - # Use per-token quantization for better perf if dynamic and cutlass - if not self.act_q_static and cutlass_fp8_supported(): - self.act_q_group_shape = GroupShape.PER_TOKEN - else: - self.act_q_group_shape = GroupShape.PER_TENSOR if self.block_quant: assert not self.act_q_static assert self.weight_block_size is not None self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( weight_group_shape=GroupShape(*self.weight_block_size), - act_quant_group_shape=self.act_q_group_shape, + act_quant_group_shape=GroupShape(1, 
self.weight_block_size[0]), cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape, + # Use per-token quantization for better perf if dynamic and cutlass + if self.act_q_static: + activation_quant_key = kFp8StaticTensorSym + elif cutlass_fp8_supported(): + activation_quant_key = kFp8DynamicTokenSym + else: + activation_quant_key = kFp8DynamicTensorSym + + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=activation_quant_key, + weight_quant_key=kFp8StaticTensorSym, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) def create_weights( @@ -362,8 +368,6 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - maybe_create_device_identity() - output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes @@ -462,8 +466,6 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): scale = create_fp8_input_scale(output_partition_sizes, weight_loader) set_weight_attrs(scale, {"scale_type": "input_scale"}) layer.register_parameter("input_scale", scale) - else: - layer.register_parameter("input_scale", None) def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): @@ -602,14 +604,7 @@ def apply( bias=bias, ) - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) class Fp8MoEMethod(FusedMoEMethodBase): diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py 
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 7be220f7a373..5f61cfcbf9c3 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -2,19 +2,58 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass +from typing import Generic, TypeVar import torch +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) +from vllm.platforms import current_platform + @dataclass class ScaledMMLinearLayerConfig: - is_channelwise: bool + pass + + +@dataclass +class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): + # TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig is_static_input_scheme: bool + is_channelwise: bool input_symmetric: bool -class ScaledMMLinearKernel(ABC): +@dataclass +class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): + weight_quant_key: QuantKey + activation_quant_key: QuantKey + out_dtype: torch.dtype | None + + +_FP8ParamsT = tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + torch.Tensor | None, # input_scale, + torch.Tensor | None, # input_scale_ub, +] +_Int8ParamsT = tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + torch.Tensor | None, # input_scale, + torch.Tensor | None, # input_zp + torch.Tensor | None, # azp_adj +] + +_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT) +_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig) + + +class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC): @classmethod @abstractmethod def is_supported( @@ -24,26 +63,14 @@ def is_supported( @classmethod @abstractmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: +
def can_implement(cls, c: _ConfigT) -> tuple[bool, str | None]: raise NotImplementedError - def __init__( - self, - c: ScaledMMLinearLayerConfig, - w_q_param_name: str, - w_s_param_name: str, - i_s_param_name: str, - i_zp_param_name: str, - azp_adj_param_name: str, - ) -> None: - assert self.can_implement(c) - assert self.is_supported() + def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None: + assert self.can_implement(c)[0] + assert self.is_supported()[0] self.config = c - self.w_q_name = w_q_param_name - self.w_s_name = w_s_param_name - self.i_s_name = i_s_param_name - self.i_zp_name = i_zp_param_name - self.azp_adj_name = azp_adj_param_name + self.layer_param_names = layer_param_names @abstractmethod def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -58,19 +85,103 @@ def apply_weights( ) -> torch.Tensor: raise NotImplementedError - def _get_weight_params( - self, layer: torch.nn.Module - ) -> tuple[ - torch.Tensor, # weight - torch.Tensor, # weight_scale - torch.Tensor | None, # input_scale, - torch.Tensor | None, # input_zp - torch.Tensor | None, # azp_adj - ]: + # return a covariant type in the subclass + @abstractmethod + def _get_layer_params(self, layer) -> _ParamsT: + raise NotImplementedError + + +class FP8ScaledMMLinearKernel( + ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC +): + def __init__( + self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str] + ) -> None: + act_scale_descriptor = c.activation_quant_key.scale + self.quant_fp8 = QuantFP8( + static=act_scale_descriptor.static, + group_shape=act_scale_descriptor.group_shape, + num_token_padding=self.get_output_padding(), + ) + self.fp8_dtype = current_platform.fp8_dtype() + super().__init__(c, layer_param_names) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def _get_layer_params(self, layer) -> _FP8ParamsT: + w, w_s, x_s, x_s_ub = self.layer_param_names + return ( + 
getattr(layer, w), + getattr(layer, w_s), + getattr(layer, x_s, None), + getattr(layer, x_s_ub, None), + ) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + fp8_dtype = self.fp8_dtype + maybe_out_dtype = self.config.out_dtype + w, w_s, x_s, x_s_ub = self._get_layer_params(layer) + + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_s computed from x. + # If static, layer.input_scale is scalar and x_s is input_scale. + # View input as 2D matrix for fp8 methods + x_2d = x.view(-1, x.shape[-1]) + output_shape = [*x.shape[:-1], w.shape[1]] + out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype + + # If input not quantized + # TODO(luka) remove this path if not used anymore + x_2d_q = x_2d + if x.dtype != fp8_dtype: + x_2d_q, x_s = self.quant_fp8( + x_2d, + x_s, + x_s_ub, + ) + return self.apply_scaled_mm( + A=x_2d_q, + B=w, + out_dtype=out_dtype, + As=x_s, + Bs=w_s, + bias=bias, + output_shape=output_shape, + ) + + @abstractmethod + def apply_scaled_mm( + self, + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None, + output_shape: list, + ) -> torch.Tensor: + raise NotImplementedError + + def get_output_padding(self) -> int | None: + return None + + +class Int8ScaledMMLinearKernel( + ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC +): + def _get_layer_params(self, layer) -> _Int8ParamsT: + w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names return ( - getattr(layer, self.w_q_name), - getattr(layer, self.w_s_name), - getattr(layer, self.i_s_name), - getattr(layer, self.i_zp_name), - getattr(layer, self.azp_adj_name), + getattr(layer, w_q), + getattr(layer, w_s), + getattr(layer, i_s, None), + getattr(layer, i_zp, None), + getattr(layer, azp_adj, None), ) diff --git 
a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 4ccc4182367a..f79985c3a145 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -2,76 +2,229 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +from typing import TypeVar +import torch + +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( - AiterScaledMMLinearKernel, + AiterInt8ScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( - CPUScaledMMLinearKernel, + CPUInt8ScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( - CutlassScaledMMLinearKernel, + CutlassFP8ScaledMMLinearKernel, + CutlassInt8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import ( + FlashInferFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( + ChannelWiseTorchFP8ScaledMMLinearKernel, + PerTensorTorchFP8ScaledMMLinearKernel, + RowWiseTorchFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( + ROCmFP8ScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, ScaledMMLinearKernel, ScaledMMLinearLayerConfig, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( - TritonScaledMMLinearKernel, + TritonInt8ScaledMMLinearKernel, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.platforms import PlatformEnum, current_platform +logger = 
init_logger(__name__) + +# in priority/performance order (when available) +_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = { + PlatformEnum.CPU: [CPUInt8ScaledMMLinearKernel], + PlatformEnum.CUDA: [ + CutlassInt8ScaledMMLinearKernel, + TritonInt8ScaledMMLinearKernel, + ], + PlatformEnum.ROCM: [AiterInt8ScaledMMLinearKernel, TritonInt8ScaledMMLinearKernel], +} + # in priority/performance order (when available) -_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { - PlatformEnum.CPU: [CPUScaledMMLinearKernel], - PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel], - PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], +_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = { + PlatformEnum.CUDA: [ + FlashInferFP8ScaledMMLinearKernel, + CutlassFP8ScaledMMLinearKernel, + PerTensorTorchFP8ScaledMMLinearKernel, + ChannelWiseTorchFP8ScaledMMLinearKernel, + ], + PlatformEnum.ROCM: [ + ROCmFP8ScaledMMLinearKernel, + PerTensorTorchFP8ScaledMMLinearKernel, + RowWiseTorchFP8ScaledMMLinearKernel, + ChannelWiseTorchFP8ScaledMMLinearKernel, + ], + PlatformEnum.CPU: [ + PerTensorTorchFP8ScaledMMLinearKernel, + ChannelWiseTorchFP8ScaledMMLinearKernel, + ], } +_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel) +_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig) + + +def is_supported_and_can_implement_kernel( + kernel: type[_KernelT], config: _KernelConfigT, compute_capability: int | None +) -> tuple[bool, str]: + # TODO: Fetch `VLLM_DISABLED_KERNELS` from vllm.envs instead. 
+ if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): + return False, f" {kernel.__name__} is disabled by environment variable" + + if compute_capability is None: + _cc = current_platform.get_device_capability() + if _cc is not None: + compute_capability = _cc[0] * 10 + _cc[1] + + is_supported, failure_reason = kernel.is_supported(compute_capability) + if not is_supported: + return False, f"{kernel.__name__} {failure_reason}." + + can_implement, failure_reason = kernel.can_implement(config) + if not can_implement: + return ( + False, + f"{kernel.__name__} {failure_reason}.", + ) + + return True, "" + def choose_scaled_mm_linear_kernel( - config: ScaledMMLinearLayerConfig, compute_capability: int | None = None -) -> type[ScaledMMLinearKernel]: + config: _KernelConfigT, + possible_kernels: dict[PlatformEnum, list[type[_KernelT]]], + compute_capability: int | None = None, + force_kernel: type[_KernelT] | None = None, +) -> type[_KernelT]: """ - Choose an ScaledMMLinearKernel that can implement the given config for the + Choose a _KernelT that can implement the given config for the given compute capability. Attempts to choose the best kernel in terms of performance. Args: - config (ScaledMMLinearLayerConfig): Description of the linear layer + config (_KernelConfigT): Description of the linear layer to be implemented. + possible_kernels (dict[PlatformEnum, list[_KernelT]]): A + dictionary of platforms and their list of possible kernels. compute_capability (Optional[int], optional): The compute capability of the target device, if None uses `current_platform` to get the compute capability. Defaults to None. + force_kernel (Optional[type[_KernelT]]): An Optional forced kernel to override + the possible_kernels if it can be implemented. If None, it will only try the + possible kernels. Raises: ValueError: If no kernel can implement the given config. Returns: - type[ScaledMMLinearKernel]: Chosen kernel. + _KernelT: Chosen kernel.
""" - failure_reasons = [] - for kernel in _POSSIBLE_KERNELS[current_platform._enum]: - if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): - failure_reasons.append(f"{kernel.__name__}: disabled by env var") - continue + failure_reason_list = [] - # If the current platform uses compute_capability, - # make sure the kernel supports the compute capability. - is_supported, reason = kernel.is_supported(compute_capability) - if not is_supported: - failure_reasons.append(f"{kernel.__name__}: {reason}") - continue + if force_kernel is not None: + can_implement, failure_reason = is_supported_and_can_implement_kernel( + force_kernel, config, compute_capability + ) + if can_implement: + return force_kernel - can_implement, reason = kernel.can_implement(config) - if not can_implement: - failure_reasons.append(f"{kernel.__name__}: {reason}") - continue + logger.info_once( + "Tried to force %s, but the kernel couldn't be implemented", + force_kernel.__name__, + scope="global", + ) - return kernel + for kernel in possible_kernels[current_platform._enum]: + is_supported_and_can_implement, failure_reason = ( + is_supported_and_can_implement_kernel(kernel, config, compute_capability) + ) + if is_supported_and_can_implement: + return kernel + failure_reason_list.append(failure_reason) raise ValueError( "Failed to find a kernel that can implement the " - "ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons) + "ScaledMM linear layer. 
Reasons: \n" + "\n".join(failure_reason_list) + ) + + +def init_fp8_linear_kernel( + activation_quant_key: QuantKey, + weight_quant_key: QuantKey, + out_dtype: torch.dtype, + force_kernel: type[FP8ScaledMMLinearKernel] | None = None, + module_name: str | None = None, +) -> FP8ScaledMMLinearKernel: + scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( + weight_quant_key=weight_quant_key, + activation_quant_key=activation_quant_key, + out_dtype=out_dtype, + ) + + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel + ) + + if module_name: + logger.info_once( + "Selected %s for %s", + kernel_type.__name__, + module_name, + scope="global", + ) + + return kernel_type( + scaled_mm_linear_kernel_config, + layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"], + ) + + +def init_int8_linear_kernel( + is_channelwise: bool, + is_static_input_scheme: bool, + input_symmetric: bool, + module_name: str, +) -> Int8ScaledMMLinearKernel: + config = Int8ScaledMMLinearLayerConfig( + is_channelwise=is_channelwise, + is_static_input_scheme=is_static_input_scheme, + input_symmetric=input_symmetric, + ) + + kernel_type = choose_scaled_mm_linear_kernel( + config, + _POSSIBLE_INT8_KERNELS, + ) + + logger.info_once( + "Selected %s for %s", + kernel_type.__name__, + module_name, + scope="global", + ) + + return kernel_type( + config, + layer_param_names=[ + "weight", + "weight_scale", + "input_scale", + "input_zero_point", + "azp_adj", + ], ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 971bd2005a23..f806da9c1c7c 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -8,60 +8,41 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.platforms import current_platform -from 
.cutlass import CutlassScaledMMLinearKernel -from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig +from .cutlass import CutlassInt8ScaledMMLinearKernel +from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig -class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): +class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel): @classmethod def is_supported( cls, compute_capability: int | None = None ) -> tuple[bool, str | None]: if not current_platform.is_rocm(): - return ( - False, - "AiterScaledMMLinearKernel requires `aiter` which is not " - + "currently supported on non-ROCm platform.", - ) - if compute_capability is None: - _cc = current_platform.get_device_capability() - if _cc is not None: - compute_capability = _cc.major * 10 + _cc.minor + return False, "Requires ROCm." + if compute_capability is not None and compute_capability < 90: - return False, f"requires capability 90, got {compute_capability}" + return False, "requires compute capability 90 and above." try: import aiter # noqa: F401 # deliberately attempt to import aiter except Exception: - return ( - False, - "AiterScaledMMLinearKernel requires `aiter` which is not " - + "installed on ROCm.", - ) + return False, "requires `aiter` to be installed." if not rocm_aiter_ops.is_linear_enabled(): return ( False, - "AiterScaledMMLinearKernel is disabled. " - + "Enable by setting `VLLM_ROCM_USE_AITER=1` " + "requires setting `VLLM_ROCM_USE_AITER=1` " + "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.", ) - return True, None @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not c.input_symmetric: - return ( - False, - "AiterScaledMMLinearKernel only supports symmetric " + "quantization.", - ) + return False, "supports symmetric quantization only." 
return True, None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - super().process_weights_after_loading(layer) - def apply_weights( self, layer: torch.nn.Module, @@ -69,28 +50,28 @@ def apply_weights( bias: torch.Tensor | None = None, ) -> torch.Tensor: """ - `AiterScaledMMLinearKernel` implements a fused version of + `AiterInt8ScaledMMLinearKernel` implements a fused version of `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` where scale_a * a and scale_b * b are implemented using numpy-style broadcasting. Currently only support per-tensor-per-tensor GEMM and per-token-per-channel GEMM through AITER - w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support + w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` also does not support ATIER block scaled GEMM and mix-precision GEMM. """ - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. # * static, i_s is scalar and x_s is i_s. symmetric = azp_adj is None assert symmetric, ( - "AiterScaledMMLinearKernel only supports symmetric quantization." + "AiterInt8ScaledMMLinearKernel only supports symmetric quantization." ) x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric) assert x_zp is None, ( - "AiterScaledMMLinearKernel only supports symmetric quantization." + "AiterInt8ScaledMMLinearKernel only supports symmetric quantization." ) out_dtype = x.dtype @@ -117,12 +98,12 @@ def apply_weights( ), ( "Currently only support per-tensor-per-tensor GEMM " + " and per-token-per-channel GEMM through AITER" - " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " + " w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` " + "does not support AITER block scaled GEMM." 
) # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects # a to be [M, K] # b to be [N, K] - # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format + # CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py index 6401b94d6278..b82f5781c289 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -14,24 +14,28 @@ from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ( + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, +) -class CPUScaledMMLinearKernel(ScaledMMLinearKernel): +class CPUInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def is_supported( cls, compute_capability: int | None = None ) -> tuple[bool, str | None]: if not current_platform.is_cpu(): - return False, "Requires CPU." + return False, "requires CPU." 
return True, None @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - weight = getattr(layer, self.w_q_name) + w_q_name, _, _, _, _ = self.layer_param_names + weight = getattr(layer, w_q_name) dtype = weight.dtype N, K = weight.size() if ( @@ -49,10 +53,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: # WEIGHT # Transpose to [K, N] for convenience - weight = getattr(layer, self.w_q_name) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names + weight = getattr(layer, w_q_name) replace_parameter( layer, - self.w_q_name, + w_q_name, torch.nn.Parameter(weight.t().data, requires_grad=False), ) @@ -61,28 +66,27 @@ def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: # If we have a fused module (QKV, MLP) with per tensor scales (thus N # scales being passed to the kernel), convert to the per-channel case. 
is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) + weight_scale = getattr(layer, w_s_name) if is_fused_module and not self.config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) # INPUT SCALE if self.config.is_static_input_scheme: - input_scale = getattr(layer, self.i_s_name) + input_scale = getattr(layer, i_s_name) if self.config.input_symmetric: replace_parameter( layer, - self.i_s_name, + i_s_name, torch.nn.Parameter(input_scale.max(), requires_grad=False), ) - setattr(layer, self.i_zp_name, None) else: - input_zero_point = getattr(layer, self.i_zp_name) + input_zero_point = getattr(layer, i_zp_name) # reconstruct the ranges int8_traits = torch.iinfo(torch.int8) @@ -92,20 +96,16 @@ def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) replace_parameter( - layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) + layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False) ) azp = ( (int8_traits.min - range_min / scale).round().to(dtype=torch.int32) ) replace_parameter( - layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) ) - else: - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - # Different from cutlass, oneDNN kernels only need the AZP adjustment # term for dynamic quantization. And s_b should be folded into the # term. 
Such as: @@ -113,38 +113,37 @@ def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias = # s_a * GEMM_output - s_a * zp_a * adj + bias if not (self.config.input_symmetric and self.config.is_static_input_scheme): - weight = getattr(layer, self.w_q_name) - weight_scale = getattr(layer, self.w_s_name) + weight = getattr(layer, w_q_name) + weight_scale = getattr(layer, w_s_name) azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32) azp_adj = azp_adj * weight_scale.squeeze() setattr( layer, - self.azp_adj_name, + azp_adj_name, torch.nn.Parameter(azp_adj, requires_grad=False), ) - else: - setattr(layer, self.azp_adj_name, None) - weight = getattr(layer, self.w_q_name) + weight = getattr(layer, w_q_name) self.dnnl_handler = ops.create_onednn_scaled_mm( weight, - getattr(layer, self.w_s_name), + getattr(layer, w_s_name), torch.get_default_dtype(), - getattr(layer, self.i_s_name) is None, + getattr(layer, i_s_name) is None, not self.config.input_symmetric, 32, ) # weight is prepacked and maintained by the dnnl_handler, # release the original weight - setattr(layer, self.w_q_name, None) + setattr(layer, w_q_name, None) del weight def process_weights_for_sgl(self, layer: torch.nn.Module) -> None: + w_q_name, w_s_name, _, _, _ = self.layer_param_names # WEIGHT - weight = getattr(layer, self.w_q_name) + weight = getattr(layer, w_q_name) packed_weight = torch.ops._C.convert_weight_packed(weight) replace_parameter( - layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False) + layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False) ) if layer.bias is not None: @@ -156,19 +155,15 @@ def process_weights_for_sgl(self, layer: torch.nn.Module) -> None: # WEIGHT SCALE # CPU SGL kernels only support per-channel. # For per-tensor quant, convert to the per-channel case. 
- weight_scale = getattr(layer, self.w_s_name) + weight_scale = getattr(layer, w_s_name) if not self.config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - setattr(layer, self.azp_adj_name, None) - def apply_weights( self, layer: torch.nn.Module, @@ -187,7 +182,7 @@ def _apply_weights_onednn( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. @@ -209,7 +204,7 @@ def _apply_weights_sgl( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, _, _, _ = self._get_weight_params(layer) + w_q, w_s, _, _, _ = self._get_layer_params(layer) return torch.ops._C.int8_scaled_mm_with_quant( x, w_q, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 2f00e0df8ed4..bcaf57bcbb26 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -11,35 +11,36 @@ ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ( + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, +) -class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): +class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def is_supported( cls, compute_capability: int | None = None ) -> tuple[bool, str | None]: if 
not current_platform.is_cuda(): - return False, "Requires CUDA." - if compute_capability is None: - _cc = current_platform.get_device_capability() - if _cc is not None: - compute_capability = _cc.major * 10 + _cc.minor - if compute_capability is not None and compute_capability < 75: - return False, f"requires capability 75, got {compute_capability}" + return False, "requires CUDA." return True, None @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names + config = self.config # WEIGHT # Cutlass kernels need transposed weight. - weight = getattr(layer, self.w_q_name) + weight = getattr(layer, w_q_name) replace_parameter( layer, - self.w_q_name, + w_q_name, torch.nn.Parameter(weight.t().data, requires_grad=False), ) @@ -48,28 +49,28 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # If we have a fused module (QKV, MLP) with per tensor scales (thus N # scales being passed to the kernel), convert to the per-channel case. 
is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) - if is_fused_module and not self.config.is_channelwise: + weight_scale = getattr(layer, w_s_name) + if is_fused_module and not config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) # INPUT SCALE - if self.config.is_static_input_scheme: - input_scale = getattr(layer, self.i_s_name) + if config.is_static_input_scheme: + input_scale = getattr(layer, i_s_name) - if self.config.input_symmetric: + if config.input_symmetric: replace_parameter( layer, - self.i_s_name, + i_s_name, torch.nn.Parameter(input_scale.max(), requires_grad=False), ) - setattr(layer, self.i_zp_name, None) + setattr(layer, i_zp_name, None) else: - input_zero_point = getattr(layer, self.i_zp_name) + input_zero_point = getattr(layer, i_zp_name) # reconstruct the ranges int8_traits = torch.iinfo(torch.int8) @@ -79,38 +80,32 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) replace_parameter( - layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) + layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False) ) # AZP loaded as int8 but used as int32 azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) replace_parameter( - layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) ) - else: - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - # azp_adj is the AZP adjustment term, used to account for weights. # It does not depend on scales or azp, so it is the same for # static and dynamic quantization. 
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md - if not self.config.input_symmetric: - weight = getattr(layer, self.w_q_name) + if not config.input_symmetric: + weight = getattr(layer, w_q_name) azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) - if self.config.is_static_input_scheme: + if config.is_static_input_scheme: # cutlass_w8a8 requires azp to be folded into azp_adj # in the per-tensor case - azp_adj = getattr(layer, self.i_zp_name) * azp_adj + azp_adj = getattr(layer, i_zp_name) * azp_adj setattr( layer, - self.azp_adj_name, + azp_adj_name, torch.nn.Parameter(azp_adj, requires_grad=False), ) - else: - setattr(layer, self.azp_adj_name, None) def apply_weights( self, @@ -118,7 +113,7 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. @@ -145,3 +140,34 @@ def apply_weights( return ops.cutlass_scaled_mm( x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias ) + + +class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if not current_platform.is_cuda(): + return False, "requires CUDA." 
+ return True, None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + return True, None + + def apply_scaled_mm( + self, + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None, + output_shape: list, + ) -> torch.Tensor: + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm( + A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias + ) + return output.view(*output_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py new file mode 100644 index 000000000000..991cda862acf --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch + +from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer + +from .ScaledMMLinearKernel import ( + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, +) + + +class FlashInferFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if not current_platform.is_cuda(): + return False, "requires CUDA." + + if not has_flashinfer(): + return False, "requires FlashInfer to be installed." + + if compute_capability is not None and compute_capability < 100: + return False, "requires compute capability 100 and above." 
+ + return True, None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() + ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() + + if not (per_tensor_activation_scales and per_tensor_weight_scales): + return False, "requires per tensor activation and weight scales." + + return True, None + + def apply_scaled_mm( + self, + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None, + output_shape: list, + ) -> torch.Tensor: + return flashinfer_scaled_fp8_mm( + A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py new file mode 100644 index 000000000000..60e71815446f --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch +from packaging import version + +from vllm.config import CompilationMode, get_current_vllm_config +from vllm.platforms import current_platform + +from .ScaledMMLinearKernel import ( + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, +) + + +class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): + """ + Base class for FP8 linear kernels using Torch. + Each subclass represents a kernel variant for + specific device capabilities and torch versions. + """ + + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if not (current_platform.is_cuda_alike() or current_platform.is_cpu()): + return False, "requires ROCm, CUDA or CPU." 
+ + if compute_capability is not None and compute_capability < 89: + return False, "requires compute capability 89 and above." + + return True, None + + def get_output_padding(self) -> int | None: + # Note: we pad the input because torch._scaled_mm is more performant + # for matrices with batch dimension > 16. + # This could change in the future. + # We also don't pad when using torch.compile, + # as it breaks with dynamic shapes. + # + # The perf gain is still relevant as of 16/1/2026 + # torch version == 2.9.0. More details in the link below: + # https://github.com/vllm-project/vllm/issues/32269 + vllm_config = get_current_vllm_config().compilation_config + pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE + return 17 if pad_output else None + + +class PerTensorTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel): + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() + ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() + + if not (per_tensor_activation_scales and per_tensor_weight_scales): + return False, "requires per tensor activation and weight scales." 
+ return True, None + + def apply_scaled_mm( + self, + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None, + output_shape: list, + ) -> torch.Tensor: + output = torch._scaled_mm( + A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape) + + +class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel): + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if not current_platform.is_rocm(): + return False, "requires ROCm." + + from vllm.platforms.rocm import on_mi3xx + + if not on_mi3xx(): + return False, "requires MI3xx." + + if compute_capability is not None and compute_capability < 94: + return False, "requires compute capability 94 and above." + + if not version.parse(torch.__version__) >= version.parse("2.7"): + return False, "requires pytorch version >=2.7." + + return True, None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() + ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() + + if c.out_dtype == torch.float16: + # hipblaslt rowwise _scaled_mm only supports BFloat16 + return False, "supports BFloat16 output data type only." + + if per_tensor_activation_scales or per_tensor_weight_scales: + return False, "cannot be used with per tensor activation and weight scales." 
+ + return True, None + + def apply_scaled_mm( + self, + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None, + output_shape: list, + ) -> torch.Tensor: + # Note: + # For now it has only been validated on ROCm platform. + # fp8 rowwise scaling in torch._scaled_mm is introduced in + # https://github.com/pytorch/pytorch/pull/144432 using + # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. + # + # For CUDA platform please validate if the torch._scaled_mm supports + # rowwise scaled GEMM before using it + + # Fused GEMM_DQ Rowwise GEMM + output = torch._scaled_mm( + A, + B, + out_dtype=out_dtype, + scale_a=As, + scale_b=Bs.t(), + bias=bias, + ) + + return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape) + + +class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel): + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() + ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() + + if per_tensor_activation_scales and per_tensor_weight_scales: + return False, "cannot be used with per tensor activation and weight scales." + + return True, None + + def apply_scaled_mm( + self, + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None, + output_shape: list, + ) -> torch.Tensor: + # Use unfused DQ due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. 
+ # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # Input scaling factors are no longer optional in _scaled_mm starting + # from pytorch 2.5. Allocating a dummy tensor to pass as scales + dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device) + + # GEMM + # This computes C = (X * W). + # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm( + A, + B, + scale_a=dummy_tensor, + scale_b=dummy_tensor, + out_dtype=torch.float32, + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, output_shape[0]) + x_scale = torch.narrow(As, 0, 0, output_shape[0]) + + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * Bs.t() + if bias is not None: + output = output + bias + return output.to(out_dtype).view(*output_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py new file mode 100644 index 000000000000..e52015ba2d11 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.platform_utils import get_cu_count +from vllm.utils.torch_utils import direct_register_custom_op + +from .ScaledMMLinearKernel import ( + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, +) + + +def rocm_per_tensor_float_w8a8_scaled_mm_impl( + A: 
torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + if ( + A.shape[0] == 1 + and B.shape[1] % 16 == 0 + and ((bias is None) or (bias.dtype == out_dtype)) + ): + output = ops.wvSplitKQ( + B.t(), + A, + out_dtype, + As, + Bs, + get_cu_count(), + bias, + ) + # Fallback + else: + output = torch._scaled_mm( + A, + B, + out_dtype=out_dtype, + scale_a=As, + scale_b=Bs, + bias=bias, + ) + return output + + +def rocm_per_tensor_float_w8a8_scaled_mm_fake( + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + return A.new_empty((*A.shape[:-1], B.shape[1]), dtype=out_dtype) + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl", + op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl, + fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake, + ) + + +class ROCmFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if not current_platform.is_rocm(): + return False, "requires ROCm." + + from vllm.platforms.rocm import on_mi3xx + + if not on_mi3xx(): + return False, "requires MI3xx." + + if not envs.VLLM_ROCM_USE_SKINNY_GEMM: + return False, "requires VLLM_ROCM_USE_SKINNY_GEMM to be enabled." + + return True, None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() + ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() + + if not (per_tensor_activation_scales and per_tensor_weight_scales): + return False, "requires per tensor activation and weight scales." 
+ + return True, None + + def apply_scaled_mm( + self, + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None, + output_shape: list, + ) -> torch.Tensor: + output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl( + A, B, out_dtype, As, Bs, bias + ) + return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py index e4286f91bcbe..d2d90ed06a7a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py @@ -14,30 +14,35 @@ ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +from .cutlass import CutlassInt8ScaledMMLinearKernel +from .ScaledMMLinearKernel import ( + Int8ScaledMMLinearLayerConfig, +) -class TritonScaledMMLinearKernel(ScaledMMLinearKernel): +class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel): @classmethod def is_supported( cls, compute_capability: int | None = None ) -> tuple[bool, str | None]: if current_platform.is_cuda_alike(): return True, None - return False, "Requires ROCm or CUDA." + return False, "requires ROCm or CUDA." @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not c.input_symmetric: - return False, "Only symmetric input is supported." + return False, "supports symmetric input only." 
return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - weight = getattr(layer, self.w_q_name) + w_q, _, i_s, _, _ = self._get_layer_params(layer) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names + replace_parameter( layer, - self.w_q_name, - torch.nn.Parameter(weight.t().data, requires_grad=False), + w_q_name, + torch.nn.Parameter(w_q.t().data, requires_grad=False), ) # WEIGHT SCALE @@ -45,29 +50,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # If we have a fused module (QKV, MLP) with per tensor scales (thus N # scales being passed to the kernel), convert to the per-channel case. is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) + weight_scale = getattr(layer, w_s_name) if is_fused_module and not self.config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) # INPUT SCALE if self.config.is_static_input_scheme: - input_scale = getattr(layer, self.i_s_name) + assert i_s is not None replace_parameter( layer, - self.i_s_name, - torch.nn.Parameter(input_scale.max(), requires_grad=False), + i_s_name, + torch.nn.Parameter(i_s.max(), requires_grad=False), ) - setattr(layer, self.i_zp_name, None) + setattr(layer, i_zp_name, None) else: - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) + setattr(layer, i_s_name, None) + setattr(layer, i_zp_name, None) - setattr(layer, self.azp_adj_name, None) + setattr(layer, azp_adj_name, None) def apply_weights( self, @@ -75,7 +80,7 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + w_q, w_s, i_s, i_zp, _ = self._get_layer_params(layer) x_q, x_s, x_zp = ops.scaled_int8_quant( x.contiguous(), i_s, i_zp, 
symmetric=True diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index ed3f000d90ce..91dfa03b8b86 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -49,6 +49,9 @@ QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, @@ -78,10 +81,12 @@ GroupShape, cutlass_fp4_supported, is_layer_skipped, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kFp8StaticTokenSym, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_block_fp8_supported, requantize_with_max_scale, ) @@ -438,8 +443,11 @@ class ModelOptFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8StaticTensorSym, + weight_quant_key=kFp8StaticTensorSym, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) def create_weights( @@ -507,13 +515,7 @@ def apply( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) class ModelOptFp8PcPtLinearMethod(LinearMethodBase): @@ -527,8 +529,11 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config - 
self.fp8_linear = Fp8LinearOp( - act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8StaticTokenSym, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) def create_weights( @@ -585,13 +590,7 @@ def apply( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=None, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) class ModelOptFp8PbWoLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 80efc29de67e..b97eddaff04a 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -17,11 +17,13 @@ Fp8KVCacheMethod, Fp8LinearMethod, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, is_layer_skipped, + kFp8DynamicTokenSym, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -97,9 +99,11 @@ def __init__(self, quant_config: PTPCFp8Config): ) super().__init__(quant_config=quant_config) # Force weight quantization - self.quant_config.is_checkpoint_fp8_serialized = False - self.fp8_linear = Fp8LinearOp( - act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8DynamicTokenSym, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -130,11 
+134,4 @@ def apply( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=None, - input_scale_ub=None, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 1e5ee93b61f2..635b5cf894ef 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -7,10 +7,18 @@ import torch from torch.nn import Parameter +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kFp8StaticTokenSym, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale, ) @@ -23,6 +31,8 @@ __all__ = ["QuarkW8A8Fp8"] +logger = init_logger(__name__) + class QuarkW8A8Fp8(QuarkScheme): def __init__( @@ -35,15 +45,16 @@ def __init__( self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic")) self.input_qscheme = cast(str, input_config.get("qscheme")) - per_token = ( + per_token_activation = ( not self.is_static_input_scheme and self.input_qscheme == "per_channel" ) - self.act_quant_group_shape = ( - GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR + per_token_weight = self.weight_qscheme == "per_channel" + + self.activation_quant_key = ( + kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym 
) - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_quant_group_shape, + self.weight_quant_key = ( + kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym ) self.out_dtype = torch.get_default_dtype() @@ -94,7 +105,7 @@ def process_weights_after_loading(self, layer) -> None: layer.input_scale = Parameter(input_scale, requires_grad=False) else: weight_scale = layer.weight_scale.data - if self.act_quant_group_shape == GroupShape.PER_TOKEN: + if self.activation_quant_key.scale.group_shape == GroupShape.PER_TOKEN: weight_scale = weight_scale.view(-1, 1) layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter @@ -106,8 +117,6 @@ def process_weights_after_loading(self, layer) -> None: # INPUT SCALE if self.is_static_input_scheme: layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) - else: - layer.input_scale = None def create_weights( self, @@ -163,17 +172,17 @@ def create_weights( input_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", input_scale) + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, + ) + def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index 42d2ed2e85ed..a7a7726bae0e 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ 
b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -7,8 +7,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - ScaledMMLinearLayerConfig, - choose_scaled_mm_linear_kernel, + init_int8_linear_kernel, ) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.parameter import ( @@ -22,8 +21,6 @@ class QuarkW8A8Int8(QuarkScheme): - _kernel_backends_being_used: set[str] = set() - def __init__( self, qscheme: str, @@ -50,18 +47,13 @@ def create_weights( ): layer.logical_widths = output_partition_sizes - scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + self.kernel = init_int8_linear_kernel( is_channelwise=(self.qscheme == "per_channel"), is_static_input_scheme=(self.is_static_input_scheme is True), input_symmetric=(self.input_symmetric is True), + module_name=self.__class__.__name__, ) - kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) - - if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__) - self._kernel_backends_being_used.add(kernel_type.__name__) - # WEIGHT weight = ModelWeightParameter( data=torch.empty( @@ -102,25 +94,21 @@ def create_weights( layer.register_parameter("weight_zero_point", weight_zero_point) # INPUT SCALE + input_zero_point = None + input_scale = None if self.is_static_input_scheme: input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader ) - layer.register_parameter("input_scale", input_scale) input_zero_point = BasevLLMParameter( data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader ) - layer.register_parameter("input_zero_point", input_zero_point) - - self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, - w_q_param_name="weight", - w_s_param_name="weight_scale", - i_s_param_name="input_scale", - 
i_zp_param_name="input_zero_point", - azp_adj_param_name="azp_adj", - ) + + layer.register_parameter("input_scale", input_scale) + layer.register_parameter("input_zero_point", input_zero_point) + if not hasattr(layer, "azp_adj"): + layer.register_parameter("azp_adj", None) # Checkpoints are serialized in quark format, which is # different from the format the kernel may want. Handle repacking here. diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 48beb977cdb1..91fc8760b5ef 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -123,6 +123,9 @@ def __str__(self): kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) +kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN) +kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True) + kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 4287922417c6..f949c0c076e7 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -1,34 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch -from packaging import version from vllm import _custom_ops as ops -from vllm import envs -from vllm.config import CompilationMode, get_current_vllm_config -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils 
import GroupShape from vllm.platforms import current_platform -from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer -from vllm.utils.platform_utils import get_cu_count -from vllm.utils.torch_utils import direct_register_custom_op - -# Input scaling factors are no longer optional in _scaled_mm starting -# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale -TORCH_DEVICE_IDENTITY = None - -# The condition to determine if it is on a platform that supports -# torch._scaled_mm rowwise feature. -# The condition is determined once as the operations -# are time-consuming. -USE_ROWWISE_TORCH_SCALED_MM = ( - current_platform.is_rocm() - and version.parse(torch.__version__) >= version.parse("2.7") - and current_platform.has_device_capability(94) -) def sparse_cutlass_supported() -> bool: @@ -140,361 +117,6 @@ def requantize_with_max_scale( return max_w_scale, weight -def maybe_create_device_identity(): - # Allocate dummy ones tensor for torch._scaled_mm - global TORCH_DEVICE_IDENTITY - if TORCH_DEVICE_IDENTITY is None: - TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) - - -def cutlass_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - # Fused GEMM_DQ - output = ops.cutlass_scaled_mm( - qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias - ) - return output.view(*output_shape) - - -def flashinfer_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - return flashinfer_scaled_fp8_mm( - qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias - ) - - -def rocm_per_tensor_w8a8_scaled_mm_impl( - qinput: torch.Tensor, - weight: torch.Tensor, - 
out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, -) -> torch.Tensor: - from vllm.platforms.rocm import on_mi3xx - - if ( - envs.VLLM_ROCM_USE_SKINNY_GEMM - and on_mi3xx() - and qinput.shape[0] == 1 - and qinput.shape[1] % 16 == 0 - and ((bias is None) or (bias.dtype == out_dtype)) - ): - output = ops.wvSplitKQ( - weight.t(), - qinput, - out_dtype, - scale_a, - scale_b, - get_cu_count(), - bias, - ) - else: - output = torch._scaled_mm( - qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b, - bias=bias, - ) - return output - - -def rocm_per_tensor_w8a8_scaled_mm_fake( - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, -) -> torch.Tensor: - return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), dtype=out_dtype) - - -def rocm_per_tensor_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, -) -> torch.Tensor: - output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( - qinput, weight, out_dtype, scale_a, scale_b, bias - ) - return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) - - -direct_register_custom_op( - op_name="rocm_per_tensor_w8a8_scaled_mm_impl", - op_func=rocm_per_tensor_w8a8_scaled_mm_impl, - fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake, -) - - -def torch_per_tensor_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, -) -> torch.Tensor: - output = torch._scaled_mm( - qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias - ) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and 
len(output) == 2: - output = output[0] - - return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) - - -def torch_per_token_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM - # when using it. - # For now it has only been validated on ROCm platform. - # fp8 rowwise scaling in torch._scaled_mm is introduced in - # https://github.com/pytorch/pytorch/pull/144432 using - # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. - # - # For CUDA platform please validate if the torch._scaled_mm supports - # rowwise scaled GEMM before using it - - # Fused GEMM_DQ Rowwise GEMM - output = torch._scaled_mm( - qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b.t(), - bias=bias, - ) - - output = torch.narrow(output, 0, 0, qinput.shape[0]) - output = output.view(*output_shape) - return output - - -def torch_channelwise_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - # Use unfused DQ due to limitations with scaled_mm - - # Symmetric quantized GEMM by definition computes the following: - # C = (s_x * X) (s_w * W) + bias - # This is equivalent to dequantizing the weights and activations - # before applying a GEMM. - # - # In order to compute quantized operands, a quantized kernel - # will rewrite the above like so: - # C = s_w * s_x * (X * W) + bias - # - # For the scaled_mm fallback case, we break this down, since it - # does not support s_w being a vector. - - # GEMM - # This computes C = (X * W). 
- # Output in fp32 to allow subsequent ops to happen in-place - output = torch._scaled_mm( - qinput, - weight, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, - out_dtype=torch.float32, - ) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, qinput.shape[0]) - x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0]) - - # DQ - # C = sw * sx * (X * W) + bias - output = output * x_scale * scale_b.t() - if bias is not None: - output = output + bias - return output.to(out_dtype).view(*output_shape) - - -def dispatch_w8a8_scaled_mm( - preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool -) -> Callable[..., torch.Tensor]: - if per_tensor_weights and per_tensor_activations: - if preferred_backend == "rocm": - return rocm_per_tensor_w8a8_scaled_mm - if preferred_backend == "flashinfer": - return flashinfer_w8a8_scaled_mm - if preferred_backend == "cutlass": - return cutlass_w8a8_scaled_mm - return torch_per_tensor_w8a8_scaled_mm - - # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if preferred_backend == "cutlass" or preferred_backend == "flashinfer": - return cutlass_w8a8_scaled_mm - - # If torch.scaled_mm supports per-channel (weights) per-token (inputs) - if ( - not per_tensor_weights - and not per_tensor_activations - and USE_ROWWISE_TORCH_SCALED_MM - ): - return torch_per_token_w8a8_scaled_mm - # Normally, torch.scaled_mm supports per tensor weights + activations only - # so fallback to naive if per channel or per token - return torch_channelwise_w8a8_scaled_mm - - -# TODO(luka): follow similar pattern for marlin and block-fp8-linear -# https://github.com/vllm-project/vllm/issues/14397 -class Fp8LinearOp: - """ - This class executes a FP8 linear layer using cutlass if supported and - torch.scaled_mm 
otherwise. - It needs to be a class instead of a method so that config can be read - in the __init__ method, as reading config is not allowed inside forward. - """ - - def __init__( - self, - act_quant_static: bool, - act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, - pad_output: bool | None = None, - ): - if current_platform.is_rocm(): - self.preferred_backend = "rocm" - elif current_platform.is_cuda() and cutlass_fp8_supported(): - if has_flashinfer() and current_platform.has_device_capability(100): - self.preferred_backend = "flashinfer" - else: - self.preferred_backend = "cutlass" - else: - self.preferred_backend = "torch" - - # Note: we pad the input because torch._scaled_mm is more performant - # for matrices with batch dimension > 16. - # This could change in the future. - # We also don't pad when using torch.compile, - # as it breaks with dynamic shapes. - if pad_output is None: - config = get_current_vllm_config().compilation_config - pad_output = ( - config.mode < CompilationMode.VLLM_COMPILE - and self.preferred_backend == "torch" - ) - - self.output_padding = 17 if pad_output else None - self.act_quant_static = act_quant_static - self.act_quant_group_shape = act_quant_group_shape - self.quant_fp8 = QuantFP8( - static=act_quant_static, - group_shape=act_quant_group_shape, - num_token_padding=self.output_padding, - ) - - def apply( - self, - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - out_dtype: torch.dtype | None = None, - input_scale: torch.Tensor | None = None, - input_scale_ub: torch.Tensor | None = None, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- - # View input as 2D matrix for fp8 methods - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[1]] - - if out_dtype is None: - out_dtype = input.dtype - - # If input not quantized - # TODO(luka) remove this path if not used anymore - if input.dtype != current_platform.fp8_dtype(): - qinput, x_scale = self.quant_fp8( - input_2d, - input_scale, - input_scale_ub, - ) - else: - qinput, x_scale = input_2d, input_scale - - # Must have dim() conditions - # In per-token quant scenario, when the number of token is 1, - # the scale will only have 1 elements. - # Without checking the dim(), - # we cannot distingushes between per-tensor and per-token quant. - # Example: - # When the number of token is 1, per-token scale is [[1]] - # When per-tensor scale is [1] or (). - per_tensor_weights = weight_scale.numel() == 1 - per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2 - - # TODO(luka) do this dispatch during init (after ScaledMM refactor) - w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( - self.preferred_backend, per_tensor_weights, per_tensor_activations - ) - - return w8a8_scaled_mm_func( - qinput=qinput, - weight=weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias, - output_shape=output_shape, - ) - - def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor,