diff --git a/benchmarks/kernels/benchmark_block_fp8_gemm.py b/benchmarks/kernels/benchmark_block_fp8_gemm.py index 8d50c3828206..9eddc907b937 100644 --- a/benchmarks/kernels/benchmark_block_fp8_gemm.py +++ b/benchmarks/kernels/benchmark_block_fp8_gemm.py @@ -9,11 +9,12 @@ import torch from vllm.benchmarks.lib.utils import default_vllm_config -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp, +from vllm.model_executor.kernels.linear import ( + init_fp8_linear_kernel, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, + create_fp8_quant_key, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED, @@ -70,11 +71,15 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass): weight_group_shape = GroupShape(block_n, block_k) act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization - linear_op = W8A8BlockFp8LinearOp( - weight_group_shape=weight_group_shape, - act_quant_group_shape=act_quant_group_shape, - cutlass_block_fp8_supported=use_cutlass, - use_aiter_and_is_supported=False, + linear_op = init_fp8_linear_kernel( + weight_quant_key=create_fp8_quant_key( + static=True, group_shape=weight_group_shape + ), + activation_quant_key=create_fp8_quant_key( + static=False, group_shape=act_quant_group_shape + ), + out_dtype=torch.get_default_dtype(), + module_name="build_w8a8_block_fp8_runner", ) def run(): diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py index 92e7402c0537..a50d3ca8e3e1 100644 --- a/tests/compile/passes/distributed/test_fusion_all_reduce.py +++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -39,7 +39,9 @@ class TestAllReduceRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): + def __init__( + self, hidden_size=16, token_num=16, eps=1e-6, dtype: 
torch.dtype = torch.float16 + ): super().__init__() self.hidden_size = hidden_size self.eps = eps @@ -78,7 +80,9 @@ def ops_in_model_after(self): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): quant_key = kFp8StaticTensorSym - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): + def __init__( + self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16 + ): super().__init__() self.hidden_size = hidden_size self.eps = eps @@ -88,6 +92,7 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): weight_shape=(hidden_size, hidden_size), activation_quant_key=self.quant_key, weight_quant_key=self.quant_key, + input_dtype=dtype, ) for i in range(3) ] @@ -127,7 +132,9 @@ def ops_in_model_before(self): class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): + def __init__( + self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16 + ): super().__init__() self.hidden_size = hidden_size self.eps = eps @@ -314,7 +321,7 @@ def all_reduce_fusion_pass_on_test_model( ) token_num = batch_size * seq_len - model = test_model_cls(hidden_size, token_num) + model = test_model_cls(hidden_size, token_num, dtype=dtype) hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) diff --git a/tests/compile/passes/distributed/test_sequence_parallelism.py b/tests/compile/passes/distributed/test_sequence_parallelism.py index 8588e0501783..667ef4e04fbe 100644 --- a/tests/compile/passes/distributed/test_sequence_parallelism.py +++ b/tests/compile/passes/distributed/test_sequence_parallelism.py @@ -109,6 +109,7 @@ def __init__(self, hidden_size=16, eps=1e-6): weight_shape=(hidden_size, hidden_size), activation_quant_key=self.quant_key, weight_quant_key=self.quant_key, + input_dtype=self.vllm_config.model_config.dtype, ) for i in range(3) ] diff --git a/tests/compile/passes/test_functionalization.py 
b/tests/compile/passes/test_functionalization.py index 8d13e622d81c..0e1d3b3a5d64 100644 --- a/tests/compile/passes/test_functionalization.py +++ b/tests/compile/passes/test_functionalization.py @@ -23,6 +23,7 @@ ModelConfig, PassConfig, VllmConfig, + get_current_vllm_config, set_current_vllm_config, ) from vllm.model_executor.layers.activation import SiluAndMul @@ -49,6 +50,7 @@ def __init__(self, hidden_size: int = 128): weight_shape=(hidden_size, hidden_size), activation_quant_key=self.quant_key, weight_quant_key=self.quant_key, + input_dtype=get_current_vllm_config().model_config.dtype, ) def forward(self, x): @@ -92,6 +94,7 @@ def __init__(self, hidden_size=16, intermediate_size=32): weight_shape=(hidden_size, intermediate_size), activation_quant_key=self.quant_key, weight_quant_key=self.quant_key, + input_dtype=get_current_vllm_config().model_config.dtype, ) def forward(self, hidden_states, residual): diff --git a/tests/compile/passes/test_fusion.py b/tests/compile/passes/test_fusion.py index 368ddc8f3bed..79e63efdfe40 100644 --- a/tests/compile/passes/test_fusion.py +++ b/tests/compile/passes/test_fusion.py @@ -9,7 +9,7 @@ import vllm.ir.ops import vllm.plugins from tests.compile.backend import TestBackend -from tests.utils import TestBlockFP8Layer, TestFP8Layer +from tests.utils import TestFP8Layer from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS from vllm.compilation.passes.fusion.rms_quant_fusion import ( @@ -28,19 +28,23 @@ VllmConfig, ) from vllm.model_executor.kernels.linear import ( + AiterFp8BlockScaledMMKernel, ChannelWiseTorchFP8ScaledMMLinearKernel, + CutlassFp8BlockScaledMMKernel, CutlassFP8ScaledMMLinearKernel, + DeepGemmFp8BlockScaledMMKernel, + FlashInferFp8DeepGEMMDynamicBlockScaledKernel, FlashInferFP8ScaledMMLinearKernel, - FP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel, ROCmFP8ScaledMMLinearKernel, RowWiseTorchFP8ScaledMMLinearKernel, + 
TritonFp8BlockScaledMMKernel, + _KernelT, ) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, - QuantKey, - ScaleDesc, + create_fp8_quant_key, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, @@ -66,9 +70,12 @@ (PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR), # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN), - # Blockwise group shapes (no kernel abstraction) - (None, GroupShape(1, 128)), - (None, GroupShape(1, 64)), + # Blockwise group shapes + (FlashInferFp8DeepGEMMDynamicBlockScaledKernel, GroupShape(1, 128)), + (CutlassFp8BlockScaledMMKernel, GroupShape(1, 128)), + (DeepGemmFp8BlockScaledMMKernel, GroupShape(1, 128)), + (TritonFp8BlockScaledMMKernel, GroupShape(1, 128)), + (TritonFp8BlockScaledMMKernel, GroupShape(1, 64)), ] # ROCm kernels @@ -80,8 +87,8 @@ # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN), # Blockwise group shapes (no kernel abstraction) - (None, GroupShape(1, 128)), - (None, GroupShape(1, 64)), + (TritonFp8BlockScaledMMKernel, GroupShape(1, 128)), + (TritonFp8BlockScaledMMKernel, GroupShape(1, 64)), ] KERNEL_GROUPSHAPE_COMBINATIONS = ( @@ -100,8 +107,8 @@ # Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True), (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False), - # Blockwise (no kernel abstraction) - (None, GroupShape(1, 128), True), + # Blockwise + (AiterFp8BlockScaledMMKernel, GroupShape(1, 128), True), ] @@ -110,8 +117,9 @@ def __init__( self, hidden_size: int, eps: float, - force_kernel: FP8ScaledMMLinearKernel | None, + force_kernel: type[_KernelT] | None, group_shape: GroupShape, + dtype: torch.dtype, use_aiter_fusion: bool = False, 
use_aiter_quant: bool = False, *args, @@ -129,54 +137,42 @@ def __init__( is_blockwise = group_shape.is_per_group() if is_blockwise: - act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape) - self.activation_quant_key = QuantKey( - dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True + block_size = group_shape.col + self.activation_quant_key = create_fp8_quant_key( + static=False, group_shape=group_shape ) - self.fp8_linear_layers = [ - TestBlockFP8Layer( - weight_shape=(hidden_size, hidden_size), - group_shape=group_shape, - cutlass_block_fp8_supported=cutlass_block_fp8_supported(), - use_aiter_and_is_supported=use_aiter_quant, - transpose_weights=use_aiter_fusion, - ) - for _ in range(3) - ] - - self.enable_quant_fp8_custom_op = ( - False - if use_aiter_quant - else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled() + self.weight_quant_key = create_fp8_quant_key( + static=True, group_shape=GroupShape(block_size, block_size) ) else: is_static = group_shape == GroupShape.PER_TENSOR - act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape) - w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape) - self.activation_quant_key = QuantKey( - dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True + self.activation_quant_key = create_fp8_quant_key( + is_static, group_shape=group_shape ) - self.weight_quant_key = QuantKey( - dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True + self.weight_quant_key = create_fp8_quant_key( + static=True, group_shape=group_shape ) - self.fp8_linear_layers = [ - TestFP8Layer( - weight_shape=(hidden_size, hidden_size), - activation_quant_key=self.activation_quant_key, - weight_quant_key=self.weight_quant_key, - force_kernel=force_kernel, - ) - for _ in range(3) - ] - # Enable aiter quantization if requested - for layer in self.fp8_linear_layers: - layer.kernel.quant_fp8.use_aiter = use_aiter_quant + self.fp8_linear_layers = [ + TestFP8Layer( + weight_shape=(hidden_size, 
hidden_size), + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + force_kernel=force_kernel, + transpose_weights=use_aiter_fusion, + input_dtype=dtype, + ) + for _ in range(3) + ] + + # Enable aiter quantization if requested + for layer in self.fp8_linear_layers: + layer.kernel.quant_fp8.use_aiter = use_aiter_quant - self.enable_quant_fp8_custom_op = self.fp8_linear_layers[ - 0 - ].is_quant_fp8_enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear_layers[ + 0 + ].is_quant_fp8_enabled() def forward(self, x): # avoid having graph input be an arg to a pattern directly @@ -354,6 +350,7 @@ def test_fusion_rmsnorm_quant( eps=eps, force_kernel=force_kernel, group_shape=group_shape, + dtype=dtype, use_aiter_fusion=False, use_aiter_quant=False, ) @@ -426,6 +423,7 @@ def test_aiter_fusion_rmsnorm_quant( eps=eps, force_kernel=force_kernel, group_shape=group_shape, + dtype=dtype, use_aiter_fusion=True, # Always use aiter fusion ops in aiter test use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization ) diff --git a/tests/compile/passes/test_fusion_attn.py b/tests/compile/passes/test_fusion_attn.py index 2c5ac7b0b614..2bbf0bda6262 100644 --- a/tests/compile/passes/test_fusion_attn.py +++ b/tests/compile/passes/test_fusion_attn.py @@ -66,6 +66,7 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.device = device self.vllm_config = vllm_config + self.dtype = vllm_config.model_config.dtype self.attn = Attention( num_heads=self.num_qo_heads, @@ -155,6 +156,7 @@ def __init__(self, *args, **kwargs): activation_quant_key=self.quant_key, weight_quant_key=self.quant_key, device=self.device, + input_dtype=self.dtype, ) w = kwargs.get("w") diff --git a/tests/compile/passes/test_mla_attn_quant_fusion.py b/tests/compile/passes/test_mla_attn_quant_fusion.py index 426fbb6a7e57..ce1fa642ad51 100644 --- a/tests/compile/passes/test_mla_attn_quant_fusion.py +++ b/tests/compile/passes/test_mla_attn_quant_fusion.py @@ -74,6 +74,7 @@ 
def __init__( self.kv_cache_dtype = kv_cache_dtype self.device = device self.vllm_config = vllm_config + self.dtype = vllm_config.model_config.dtype # Create kv_b_proj (ColumnParallelLinear) on device. # Reuse weights from prior model instance when available, because @@ -190,6 +191,7 @@ def __init__(self, *args, **kwargs): activation_quant_key=self.quant_key, weight_quant_key=self.quant_key, device=self.device, + input_dtype=self.dtype, ) w = kwargs.get("w") diff --git a/tests/compile/passes/test_silu_mul_quant_fusion.py b/tests/compile/passes/test_silu_mul_quant_fusion.py index 383d59d03a7d..f3d800b2815c 100644 --- a/tests/compile/passes/test_silu_mul_quant_fusion.py +++ b/tests/compile/passes/test_silu_mul_quant_fusion.py @@ -36,9 +36,9 @@ ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, + create_fp8_quant_key, kFp8Dynamic128Sym, kFp8StaticTensorSym, kNvfp4Dynamic, @@ -58,7 +58,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module): quant_key = kFp8StaticTensorSym def __init__( - self, hidden_size: int, force_kernel: FP8ScaledMMLinearKernel, **kwargs + self, + hidden_size: int, + force_kernel: FP8ScaledMMLinearKernel, + dtype: torch.dtype, + **kwargs, ): super().__init__() self.silu_and_mul = SiluAndMul() @@ -68,6 +72,7 @@ def __init__( activation_quant_key=self.quant_key, weight_quant_key=self.quant_key, force_kernel=force_kernel, + input_dtype=dtype, ) self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() @@ -137,14 +142,20 @@ def ops_in_model_after(self): class TestSiluMulGroupFp8QuantModel(torch.nn.Module): - def __init__(self, hidden_size: int, **kwargs): + act_quant_key = kFp8Dynamic128Sym + + def __init__(self, hidden_size: int, dtype: torch.dtype, **kwargs): super().__init__() 
self.silu_and_mul = SiluAndMul() - self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(128, 128), - act_quant_group_shape=GroupShape(1, 128), - cutlass_block_fp8_supported=False, - use_aiter_and_is_supported=True, + self.weight_quant_key = create_fp8_quant_key( + static=True, group_shape=GroupShape(hidden_size, hidden_size) + ) + + self.w8a8_block_fp8_linear = TestFP8Layer( + weight_shape=(hidden_size, hidden_size), + weight_quant_key=self.weight_quant_key, + activation_quant_key=self.act_quant_key, + input_dtype=dtype, ) self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() @@ -157,7 +168,7 @@ def __init__(self, hidden_size: int, **kwargs): def forward(self, x): y = self.silu_and_mul(x) - x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale) + x2 = self.w8a8_block_fp8_linear(y, self.w, self.wscale) return x2 def ops_in_model_before(self): @@ -324,7 +335,9 @@ def test_fusion_silu_and_mul_quant( passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)] backend = TestBackend(*passes) - model = model_class(hidden_size=hidden_size, force_kernel=force_kernel, x=x) + model = model_class( + hidden_size=hidden_size, force_kernel=force_kernel, x=x, dtype=dtype + ) # First dimension dynamic torch._dynamo.mark_dynamic(x, 0) diff --git a/tests/conftest.py b/tests/conftest.py index 38f2bc097f8d..a666c5a86637 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -246,8 +246,9 @@ def default_vllm_config(): """ from vllm.config import VllmConfig, set_current_vllm_config - with set_current_vllm_config(VllmConfig()): - yield + config = VllmConfig() + with set_current_vllm_config(config): + yield config @pytest.fixture() diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 936516576ce1..4cb638e47af0 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -12,8 +12,8 @@ 
native_w8a8_block_matmul, ) from vllm.config import VllmConfig +from vllm.model_executor.kernels.linear.scaled_mm.cutlass import cutlass_scaled_mm from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm, ) @@ -202,7 +202,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes are supported by deepgemm if not should_use_deepgemm_for_fp8_linear( - output_dtype=out_dtype, weight=B_fp32, supports_deep_gemm=True + output_dtype=out_dtype, weight_shape=B_fp32.shape, supports_deep_gemm=True ): pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index badcac005467..bd8c85e95bdd 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -16,6 +16,9 @@ ) from tests.models.utils import check_logprobs_close +from vllm.model_executor.kernels.linear import ( + Fp8BlockScaledMMLinearKernel, +) from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig, @@ -29,7 +32,6 @@ CompressedTensorsWNA16, ) from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( cutlass_fp4_supported, ) @@ -473,16 +475,14 @@ def check_model(model): qkv_proj = layer.self_attn.qkv_proj assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8) - assert isinstance( - qkv_proj.scheme.w8a8_block_fp8_linear, W8A8BlockFp8LinearOp - ) + assert isinstance(qkv_proj.scheme.fp8_linear, Fp8BlockScaledMMLinearKernel) assert 
qkv_proj.weight.dtype is fp8_dtype assert qkv_proj.weight_scale.dtype is torch.float32 assert len(qkv_proj.weight.shape) == 2 assert len(qkv_proj.weight_scale.shape) == 2 - input_quant_op = qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op + input_quant_op = qkv_proj.scheme.fp8_linear.quant_fp8 assert isinstance(input_quant_op, QuantFP8) assert input_quant_op._forward_method in ( input_quant_op.forward_cuda, diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 6f8e0f87b890..4209d59ba286 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -13,6 +13,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import _custom_ops as ops +from vllm.config.model import ModelConfig from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.quantization.fp8 import ( Fp8Config, @@ -406,6 +407,8 @@ def test_fp8_reloading( "If this is your use case, consider using a restore function like #26327" ) + # Set model config as model_config.dtype is required in Fp8LinearMethod. 
+ default_vllm_config.model_config = ModelConfig() with torch.device("cuda:0"): config = Fp8Config( is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, diff --git a/tests/quantization/test_modelopt.py b/tests/quantization/test_modelopt.py index 154b29d7017a..120b2cde0f35 100644 --- a/tests/quantization/test_modelopt.py +++ b/tests/quantization/test_modelopt.py @@ -12,6 +12,7 @@ import torch from tests.quantization.utils import is_quant_method_supported +from vllm.config.model import ModelConfig @pytest.fixture(scope="function", autouse=True) @@ -46,7 +47,7 @@ def _snapshot_download_or_skip(model_id: str) -> str: not is_quant_method_supported("modelopt"), reason="ModelOpt FP8 is not supported on this GPU type.", ) -def test_modelopt_fp8_checkpoint_setup(vllm_runner): +def test_modelopt_fp8_checkpoint_setup(default_vllm_config, vllm_runner): """Test ModelOpt FP8 checkpoint loading and structure validation.""" # TODO: provide a small publicly available test checkpoint model_path = ( @@ -61,6 +62,8 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner): "This test requires a local ModelOpt FP8 checkpoint." ) + # Set model config as model_config.dtype is required in ModelOptFp8LinearMethod. + default_vllm_config.model_config = ModelConfig() with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm: def check_model(model): @@ -120,11 +123,13 @@ def check_model(model): not is_quant_method_supported("modelopt"), reason="ModelOpt FP8 is not supported on this GPU type.", ) -def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner): +def test_modelopt_fp8_pc_pt_checkpoint_setup(default_vllm_config, vllm_runner): """Test ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoint setup.""" model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pc-pt" model_path = _snapshot_download_or_skip(model_id) + # Set model config as model_config.dtype is required in ModelOptFp8LinearMethod. 
+ default_vllm_config.model_config = ModelConfig() with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm: def check_model(model): @@ -181,11 +186,13 @@ def check_model(model): not is_quant_method_supported("modelopt"), reason="ModelOpt FP8 is not supported on this GPU type.", ) -def test_modelopt_fp8_pb_wo_checkpoint_setup(vllm_runner): +def test_modelopt_fp8_pb_wo_checkpoint_setup(default_vllm_config, vllm_runner): """Test ModelOpt FP8_PB_WO checkpoint setup.""" model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pb-wo" model_path = _snapshot_download_or_skip(model_id) + # Set model config as model_config.dtype is required in ModelOptFp8LinearMethod. + default_vllm_config.model_config = ModelConfig() with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm: def check_model(model): diff --git a/tests/utils.py b/tests/utils.py index 84a061f295aa..7af72cb730b0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -43,12 +43,10 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.cli.serve import ServeSubcommand from vllm.model_executor.kernels.linear import ( - FP8ScaledMMLinearKernel, + _KernelT, init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, QuantKey, ) from vllm.model_executor.model_loader import get_model_loader @@ -1811,31 +1809,52 @@ def __init__( weight_shape: tuple[int, int], activation_quant_key: QuantKey, weight_quant_key: QuantKey, + input_dtype: torch.dtype, out_dtype: torch.dtype | None = None, + transpose_weights: bool = False, device: torch.device | None = None, - force_kernel: FP8ScaledMMLinearKernel | None = None, + force_kernel: type[_KernelT] | None = None, ): super().__init__() - per_tensor_weights = weight_quant_key.scale.group_shape.is_per_tensor() - is_static_activation_scale = activation_quant_key.scale.static - 
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1) - - self.weight_scale = torch.rand( - weight_scale_shape, dtype=torch.float32, device=device - ) - self.input_scale = ( - torch.rand(1, dtype=torch.float32, device=device) - if is_static_activation_scale - else None - ) - self.weight = torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t() - self.input_scale_ub = None + act_scale_desc = activation_quant_key.scale + weight_scale_desc = weight_quant_key.scale + is_block_wise = act_scale_desc.group_shape.is_per_group() + if is_block_wise: + block_size = weight_scale_desc.group_shape.col + weight_scale_shape = weight_shape[0] // block_size + self.weight_scale_inv = torch.rand( + (weight_scale_shape, weight_scale_shape), dtype=torch.float32 + ) + self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE) + self.input_scale = None + self.weight_scale = None + if transpose_weights: + self.weight = self.weight.t() + else: + per_tensor_weights = weight_scale_desc.group_shape.is_per_tensor() + is_static_activation_scale = act_scale_desc.static + weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1) + self.weight_scale_inv = None + self.weight_scale = torch.rand( + weight_scale_shape, dtype=torch.float32, device=device + ) + self.input_scale = ( + torch.rand(1, dtype=torch.float32, device=device) + if is_static_activation_scale + else None + ) + self.weight = ( + torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t() + ) + self.input_scale_ub = None out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype self.kernel = init_fp8_linear_kernel( activation_quant_key=activation_quant_key, weight_quant_key=weight_quant_key, + weight_shape=weight_shape, + input_dtype=input_dtype, out_dtype=out_dtype, force_kernel=force_kernel, ) @@ -1847,61 +1866,3 @@ def forward( self, y: torch.Tensor, bias: torch.Tensor | None = None ) -> torch.Tensor: return self.kernel.apply_weights(self, y, bias) - - -# TODO: Drop 
TestBlockFP8Layer in favour of a unified TestFP8Layer -# after refactoring W8A8BlockFp8LinearOp. -# https://github.com/vllm-project/vllm/issues/31818 -class TestBlockFP8Layer: - """ - Test helper for blockwise FP8 linear operations. Creates random weights - and scales for W8A8BlockFp8LinearOp. - - This is a workaround until W8A8BlockFp8LinearOp implements the kernel - abstraction (ScaledMMLinearKernel) for blockwise quantization. - - Args: - weight_shape: Shape of the weight tensor (out_features, in_features). - group_shape: Blockwise quantization group shape. - cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available. - use_aiter_and_is_supported: Whether to use aiter quantization ops. - transpose_weights: Whether to transpose weights after creation. - """ - - def __init__( - self, - weight_shape: tuple[int, int], - group_shape: GroupShape, - cutlass_block_fp8_supported: bool = False, - use_aiter_and_is_supported: bool = False, - transpose_weights: bool = False, - ): - weight_scale_shape = weight_shape[0] // group_shape[1] - self.weight_scale = torch.rand( - (weight_scale_shape, weight_scale_shape), dtype=torch.float32 - ) - self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE) - self.input_scale = None - if transpose_weights: - self.weight = self.weight.t() - - self.linear_op = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(group_shape[1], group_shape[1]), - act_quant_group_shape=group_shape, - cutlass_block_fp8_supported=cutlass_block_fp8_supported, - use_aiter_and_is_supported=use_aiter_and_is_supported, - ) - - def __call__( - self, y: torch.Tensor, bias: torch.Tensor | None = None - ) -> torch.Tensor: - return self.linear_op.apply( - input=y, - weight=self.weight, - weight_scale=self.weight_scale, - input_scale=self.input_scale, - bias=bias, - ) - - def is_quant_fp8_enabled(self) -> bool: - return self.linear_op.input_quant_op.enabled() diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 
dee7cdde744d..ff6f45b3dee4 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -998,11 +998,11 @@ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any: ) hash_content = [] for filepath in forward_code_files: - hash_content.append(filepath) if filepath == "": # This means the function was dynamically generated, with # e.g. exec(). We can't actually check these. continue + hash_content.append(filepath) try: with open(filepath) as f: hash_content.append(f.read()) diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py index cfef32056706..774e92c228bf 100644 --- a/vllm/model_executor/kernels/linear/__init__.py +++ b/vllm/model_executor/kernels/linear/__init__.py @@ -19,6 +19,10 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.model_executor.kernels.linear.base import ( + MMLinearKernel, + MMLinearLayerConfig, +) from vllm.model_executor.kernels.linear.mixed_precision import ( MPLinearKernel, MPLinearLayerConfig, @@ -52,24 +56,30 @@ XPUwNa16LinearKernel, ) from vllm.model_executor.kernels.linear.scaled_mm import ( + Fp8BlockScaledMMLinearKernel, FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, ScaledMMLinearKernel, - ScaledMMLinearLayerConfig, ) from vllm.model_executor.kernels.linear.scaled_mm.aiter import ( + AiterFp8BlockScaledMMKernel, AiterInt8ScaledMMLinearKernel, ) from vllm.model_executor.kernels.linear.scaled_mm.cpu import ( CPUInt8ScaledMMLinearKernel, ) from vllm.model_executor.kernels.linear.scaled_mm.cutlass import ( + CutlassFp8BlockScaledMMKernel, CutlassFP8ScaledMMLinearKernel, CutlassInt8ScaledMMLinearKernel, ) +from vllm.model_executor.kernels.linear.scaled_mm.deep_gemm import ( + DeepGemmFp8BlockScaledMMKernel, +) from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import ( + FlashInferFp8DeepGEMMDynamicBlockScaledKernel, 
FlashInferFP8ScaledMMLinearKernel, ) from vllm.model_executor.kernels.linear.scaled_mm.marlin import ( @@ -84,6 +94,7 @@ ROCmFP8ScaledMMLinearKernel, ) from vllm.model_executor.kernels.linear.scaled_mm.triton import ( + TritonFp8BlockScaledMMKernel, TritonInt8ScaledMMLinearKernel, ) from vllm.model_executor.kernels.linear.scaled_mm.xpu import ( @@ -128,6 +139,23 @@ ], } + +# in priority/performance order (when available) +_POSSIBLE_FP8_BLOCK_KERNELS: dict[ + PlatformEnum, list[type[Fp8BlockScaledMMLinearKernel]] +] = { + PlatformEnum.CUDA: [ + FlashInferFp8DeepGEMMDynamicBlockScaledKernel, + DeepGemmFp8BlockScaledMMKernel, + CutlassFp8BlockScaledMMKernel, + TritonFp8BlockScaledMMKernel, + ], + PlatformEnum.ROCM: [ + AiterFp8BlockScaledMMKernel, + TritonFp8BlockScaledMMKernel, + ], +} + # in priority/performance order (when available) _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = { PlatformEnum.CUDA: [ @@ -152,8 +180,10 @@ ], } -_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel) -_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig) +# TODO make all kernels inherit from MMLinearKernel +# then bound _KernelT only to MMLinearKernel +_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel | MMLinearKernel) +_KernelConfigT = TypeVar("_KernelConfigT", bound=MMLinearLayerConfig) def is_supported_and_can_implement_kernel( @@ -243,32 +273,61 @@ def choose_scaled_mm_linear_kernel( def init_fp8_linear_kernel( activation_quant_key: QuantKey, weight_quant_key: QuantKey, + weight_shape: tuple[int, int], + input_dtype: torch.dtype, out_dtype: torch.dtype, - force_kernel: type[FP8ScaledMMLinearKernel] | None = None, + force_kernel: type[_KernelT] | None = None, module_name: str | None = None, -) -> FP8ScaledMMLinearKernel: +) -> FP8ScaledMMLinearKernel | Fp8BlockScaledMMLinearKernel: scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( weight_quant_key=weight_quant_key, activation_quant_key=activation_quant_key, + 
weight_shape=weight_shape, + input_dtype=input_dtype, out_dtype=out_dtype, ) - kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel - ) + if activation_quant_key.scale.group_shape.is_per_group(): + kernel_type = choose_scaled_mm_linear_kernel( + config=scaled_mm_linear_kernel_config, + possible_kernels=_POSSIBLE_FP8_BLOCK_KERNELS, # type: ignore[misc] + force_kernel=force_kernel, + ) + if module_name: + logger.info_once( + "Selected %s for %s", + kernel_type.__name__, + module_name, + scope="global", + ) - if module_name: - logger.info_once( - "Selected %s for %s", - kernel_type.__name__, - module_name, - scope="global", + return kernel_type( + scaled_mm_linear_kernel_config, ) - return kernel_type( - scaled_mm_linear_kernel_config, - layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"], - ) + else: + kernel_type = choose_scaled_mm_linear_kernel( + config=scaled_mm_linear_kernel_config, + possible_kernels=_POSSIBLE_FP8_KERNELS, # type: ignore[misc] + force_kernel=force_kernel, + ) + if module_name: + logger.info_once( + "Selected %s for %s", + kernel_type.__name__, + module_name, + scope="global", + ) + + return kernel_type( + scaled_mm_linear_kernel_config, + layer_param_names=[ + "weight", + "weight_scale", + "input_scale", + "input_scale_ub", + ], + ) def init_int8_linear_kernel( @@ -433,4 +492,7 @@ def register_linear_kernel( "MarlinLinearKernel", "XPUW4A8IntLinearKernel", "XPUwNa16LinearKernel", + "_KernelT", + "DeepGemmFp8BlockScaledMMKernel", + "FlashInferFp8DeepGEMMDynamicBlockScaledKernel", ] diff --git a/vllm/model_executor/kernels/linear/base.py b/vllm/model_executor/kernels/linear/base.py new file mode 100644 index 000000000000..4e9b89bb3ff1 --- /dev/null +++ b/vllm/model_executor/kernels/linear/base.py @@ -0,0 +1,324 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, 
abstractmethod +from dataclasses import dataclass +from typing import Any, ClassVar, Generic, TypeVar + +import torch +from typing_extensions import Self + + +@dataclass +class MMLinearLayerConfig: ... + + +@dataclass +class Params: + """Base class for quantized layer parameters. + + This class provides a typed interface for accessing quantized weights and scales + from layer modules. It serves as a parameter container that can be extracted from + layers and passed to kernel implementations. + + Attributes: + weight: The quantized weight tensor + weight_scale: weight scaling factors + input_scale: Optional input scaling factors + + Class Variables: + WEIGHT: Attribute name for weight tensor on the layer module + WEIGHT_SCALE: Attribute name for weight scale tensor on the layer module + INPUT_SCALE: Attribute name for input scale tensor on the layer module + + Important: + The string values of WEIGHT, WEIGHT_SCALE, and INPUT_SCALE class variables + MUST match the attribute names used in the corresponding quantization method's + create_weights() implementation. + For example, if FP8LinearMethod.create_weights() + sets layer.weight and layer.weight_scale, + then WEIGHT="weight" and + WEIGHT_SCALE="weight_scale" must be used here. 
+ + Usage: + ```python + # Extract parameters from a quantized layer + params = Params.from_layer(layer) + + # Access typed parameters + output = func(input, params.weight, params.weight_scale) + ``` + """ + + weight: torch.Tensor + weight_scale: torch.Tensor + input_scale: torch.Tensor | None + + # Attribute names on the layer + WEIGHT: ClassVar[str] = "weight" + WEIGHT_SCALE: ClassVar[str] = "weight_scale" + INPUT_SCALE: ClassVar[str] = "input_scale" + + @classmethod + def from_layer(cls, layer: torch.nn.Module) -> Self: + return cls( + weight=getattr(layer, cls.WEIGHT), + weight_scale=getattr(layer, cls.WEIGHT_SCALE), + input_scale=getattr(layer, cls.INPUT_SCALE, None), + ) + + +@dataclass +class FP8Params(Params): + """FP8 layer parameters with typed fields""" + + input_scale_ub: torch.Tensor | None + + INPUT_SCALE_UB: ClassVar[str] = "input_scale_ub" + + @classmethod + def from_layer(cls, layer: torch.nn.Module) -> "FP8Params": + """Extract parameters from layer""" + return cls( + weight=getattr(layer, cls.WEIGHT), + weight_scale=getattr(layer, cls.WEIGHT_SCALE), + input_scale=getattr(layer, cls.INPUT_SCALE, None), + input_scale_ub=getattr(layer, cls.INPUT_SCALE_UB, None), + ) + + +@dataclass +class Int8Params(Params): + """Int8 layer parameters with typed fields""" + + input_zero_point: torch.Tensor | None + azp_adj: torch.Tensor | None + + INPUT_ZERO_POINT: ClassVar[str] = "input_zero_point" + AZP_ADJ: ClassVar[str] = "azp_adj" + + @classmethod + def from_layer(cls, layer: torch.nn.Module) -> "Int8Params": + """Extract parameters from layer""" + return cls( + weight=getattr(layer, cls.WEIGHT), + weight_scale=getattr(layer, cls.WEIGHT_SCALE), + input_scale=getattr(layer, cls.INPUT_SCALE, None), + input_zero_point=getattr(layer, cls.INPUT_ZERO_POINT, None), + azp_adj=getattr(layer, cls.AZP_ADJ, None), + ) + + +_ParamsT = TypeVar("_ParamsT", bound=Params) +_ConfigT = TypeVar("_ConfigT", bound=MMLinearLayerConfig) + + +class MMLinearKernel(ABC, Generic[_ConfigT, 
_ParamsT]): + """Abstract base class for quantized matrix multiplication kernels. + + This class provides the interface for implementing custom quantized linear layer + kernels in vLLM. Subclasses should implement specific quantization strategies + (e.g., FP8, INT8) and their corresponding compute kernels. + + Generic Type Parameters: + _ConfigT: Configuration type for the kernel (subclass of MMLinearLayerConfig). + Contains kernel-specific settings like quantization keys, dtypes, etc. + _ParamsT: Parameter type for the kernel (subclass of Params). + Defines the quantized weights and scales needed by the kernel. + + Typical Usage: + 1. Define a config dataclass inheriting from MMLinearLayerConfig + 2. Define a params dataclass inheriting from Params (or FP8Params/Int8Params) + 3. Subclass MMLinearKernel with your config and params types + 4. Implement all abstract methods + 5. Register the kernel with the quantization method + + Example: + ```python + @dataclass + class MyKernelConfig(MMLinearLayerConfig): + static: bool + output_dtype: torch.dtype + + + @dataclass + class MyKernelParams(FP8Params): + custom_scale: torch.Tensor + CUSTOM_SCALE: ClassVar[str] = "custom_scale" + + + class MyKernel(MMLinearKernel[MyKernelConfig, MyKernelParams]): + @classmethod + def is_supported(cls, compute_capability=None): + if compute_capability and compute_capability < 90: + return False, "Requires compute capability >= 9.0" + return True, None + + @classmethod + def can_implement(cls, config): + if not config.static: + return False, "Only static quantization supported" + return True, None + + def process_weights_after_loading(self, layer): + # Preprocess weights for the kernel + params = self._get_layer_params(layer) + processed = preprocess_weights(params.weight) + replace_parameter(layer, params.WEIGHT, processed) + + def _get_layer_params(self, layer, **kwargs): + return MyKernelParams.from_layer(layer) + + def apply_weights(self, layer, x, bias=None, **kwargs): + params = 
self._get_layer_params(layer) + # Call your custom kernel + output = my_custom_kernel(x, params.weight, params.weight_scale) + if bias is not None: + output += bias + return output + ``` + + Lifecycle: + 1. Kernel selection: is_supported() and can_implement() check compatibility + 2. Initialization: __init__() creates kernel instance with config + 3. Weight loading: process_weights_after_loading() preprocesses weights + 4. Inference: apply_weights() executes the quantized matmul + """ + + @classmethod + @abstractmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + """Check if this kernel is supported on the current hardware. + + This method checks hardware-level compatibility (e.g., GPU architecture, + compute capability, available instructions). It's called during kernel + selection to filter out kernels that cannot run on the current device. + + Args: + compute_capability: GPU compute capability (e.g., 80 for A100, 90 for H100). + If None, should check the current device. + + Returns: + A tuple of (is_supported, reason): + - is_supported: True if the kernel can run on this hardware + - reason: If not supported, a string explaining why; otherwise None + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, config: _ConfigT) -> tuple[bool, str | None]: + """Check if this kernel can implement the given configuration. + + This method checks configuration-level compatibility (e.g., quantization + scheme, group sizes, static vs dynamic quantization). It's called after + is_supported() to determine if this kernel can handle the specific + quantization configuration. 
+ + Args: + config: The kernel configuration to check + + Returns: + A tuple of (can_implement, reason): + - can_implement: True if this kernel supports the config + - reason: If not supported, a string explaining why; otherwise None + ``` + """ + raise NotImplementedError + + def __init__(self, config: _ConfigT) -> None: + """Initialize the kernel with the given configuration. + + Args: + config: Kernel-specific configuration containing settings like + quantization keys, output dtypes, etc. + """ + self.config = config + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Process and transform weights after loading from checkpoint. + + This method is called once after weights are loaded but before inference. + Use it to preprocess weights into the format required by your kernel + (e.g., reordering, padding, format conversion). + + Modifications should be done in-place using replace_parameter() to ensure + the layer's parameters are properly updated. + + Args: + layer: The layer module containing the weights to process + + Example: + ```python + def process_weights_after_loading(self, layer): + params = self._get_layer_params(layer) + # Reorder weights for better memory access + weight_reordered = reorder_weights(params.weight) + replace_parameter(layer, params.WEIGHT, weight_reordered) + ``` + """ + raise NotImplementedError + + # return a covariant type in the subclass + @abstractmethod + def _get_layer_params(self, layer: torch.nn.Module, **kwargs: Any) -> _ParamsT: + """Extract typed parameters from the layer module. + + This internal method retrieves the quantized weights and scales from + the layer as a typed parameter object. Subclasses should typically + delegate to ParamsClass.from_layer(). 
+ + Args: + layer: The layer module containing the parameters + **kwargs: Additional arguments + + Returns: + A typed parameter object containing weights, scales, and other + quantization parameters + + Example: + ```python + def _get_layer_params(self, layer, **kwargs): + return MyKernelParams.from_layer(layer) + ``` + """ + raise NotImplementedError + + def get_output_padding(self) -> int | None: + """Get the number of output tokens to pad for this kernel. + + Some kernels require input padding for optimal performance. + Override this method to specify padding requirements. + + Returns: + Number of tokens to pad, or None for no padding (default) + """ + return None + + @abstractmethod + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + **kwargs: Any, + ) -> torch.Tensor: + """Apply the quantized weights to the input tensor. + + This is the main inference method that performs the quantized matrix + multiplication. It should handle input quantization (if needed), call + the underlying kernel, and apply bias. 
+ + Args: + layer: The layer module containing the quantized weights + x: Input tensor of shape [..., in_features] + bias: Optional bias tensor of shape [out_features] + **kwargs: Additional kernel-specific arguments + + Returns: + Output tensor of shape [..., out_features] + """ + raise NotImplementedError diff --git a/vllm/model_executor/kernels/linear/scaled_mm/BlockScaledMMLinearKernel.py b/vllm/model_executor/kernels/linear/scaled_mm/BlockScaledMMLinearKernel.py new file mode 100644 index 000000000000..d738796ce6bf --- /dev/null +++ b/vllm/model_executor/kernels/linear/scaled_mm/BlockScaledMMLinearKernel.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import ClassVar + +import torch +from typing_extensions import Self + +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + process_fp8_weight_block_strategy, +) +from vllm.model_executor.utils import replace_parameter + +from ..base import ( + FP8Params, + MMLinearKernel, +) +from .ScaledMMLinearKernel import FP8ScaledMMLinearLayerConfig + + +@dataclass +class FP8BlockParams(FP8Params): + weight_scale_inv: torch.Tensor | None + weight_scale: torch.Tensor | None + + WEIGHT_SCALE_INV: ClassVar[str] = "weight_scale_inv" + + @classmethod + def from_layer(cls, layer: torch.nn.Module) -> Self: + return cls( + weight=getattr(layer, cls.WEIGHT), + weight_scale_inv=getattr(layer, cls.WEIGHT_SCALE_INV, None), + weight_scale=getattr(layer, cls.WEIGHT_SCALE, None), + input_scale=getattr(layer, cls.INPUT_SCALE, None), + input_scale_ub=getattr(layer, cls.INPUT_SCALE_UB, None), + ) + + +class Fp8BlockScaledMMLinearKernel( + MMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8BlockParams], ABC +): + # Set to False in subclasses that accept BF16 input directly (e.g. 
FlashInfer) + # and therefore do not need the input quantization step in apply_weights. + apply_input_quant: ClassVar[bool] = True + + def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None: + super().__init__(config) + act_scale_descriptor = config.activation_quant_key.scale + self.weight_group_shape = config.weight_quant_key.scale.group_shape + self.quant_fp8 = QuantFP8( + static=act_scale_descriptor.static, + group_shape=act_scale_descriptor.group_shape, + num_token_padding=self.get_output_padding(), + use_ue8m0=False, + ) + self.use_triton = False + + @classmethod + def can_implement(cls, config: FP8ScaledMMLinearLayerConfig): + act_quant_key = config.activation_quant_key + if act_quant_key.scale.static: + return ( + False, + "Only dynamic per token group activation quantization is supported.", + ) + + return True, None + + def _get_layer_params(self, layer: torch.nn.Module, **kwargs) -> FP8BlockParams: + return FP8BlockParams.from_layer(layer) + + def process_weights_after_loading(self, layer: torch.nn.Module): + params = self._get_layer_params(layer) + # Fp8LinearMethod registered weight scale + # buffer as weight_scale_inv unlike compressed tensors. 
+ weight_scale = ( + params.weight_scale + if params.weight_scale_inv is None + else params.weight_scale_inv + ) + scale_attr_name = ( + params.WEIGHT_SCALE + if params.weight_scale_inv is None + else params.WEIGHT_SCALE_INV + ) + new_weight, new_weight_scale = process_fp8_weight_block_strategy( + params.weight, + weight_scale, + ) + + replace_parameter(layer, params.WEIGHT, new_weight.data) + replace_parameter(layer, scale_attr_name, new_weight_scale.data) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + out_dtype = self.config.out_dtype + params = self._get_layer_params(layer) + weight = params.weight + weight_scale = ( + params.weight_scale + if params.weight_scale_inv is None + else params.weight_scale_inv + ) + input_scale = params.input_scale + scale_up = params.input_scale_ub + + # View input as 2D matrix for fp8 methods + input_2d = x.view(-1, x.shape[-1]) + output_shape = [*x.shape[:-1], weight.shape[0]] + + if self.apply_input_quant: + q_input, input_scale = self.quant_fp8( + input_2d, input_scale, scale_up, use_triton=self.use_triton + ) + else: + q_input = input_2d + # Provide a concrete placeholder so apply_block_scaled_mm args are + # always Tensors. Subclasses with apply_input_quant=False must not + # use As in apply_block_scaled_mm. 
+ input_scale = ( + input_scale if input_scale is not None else input_2d.new_ones(1) + ) + + output = self.apply_block_scaled_mm( + A=q_input, + B=weight, + As=input_scale, + Bs=weight_scale, + ) + + if bias is not None: + output = output + bias + return output.to(dtype=out_dtype).view(*output_shape) + + @abstractmethod + def apply_block_scaled_mm( + self, + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + +class Fp8BlockScaledDynamicMMLinearKernel(Fp8BlockScaledMMLinearKernel, ABC): + """Dynamic FP8 block-scaled kernel that dispatches at runtime. + + Extends Fp8BlockScaledMMLinearKernel to inherit apply_weights and overrides + apply_block_scaled_mm to dispatch between two sub-kernels using torch.cond. + + Subclasses must define: + base_type: The primary kernel class. + fallback_type: The fallback kernel class. + """ + + base_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]] + fallback_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]] + + def __init__(self, config: "FP8ScaledMMLinearLayerConfig") -> None: + super().__init__(config) + self.base = self.base_type(config) + self.fallback = self.fallback_type(config) + + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + is_base_supported, reason_1 = cls.base_type.is_supported(compute_capability) + is_fallback_supported, reason_2 = cls.fallback_type.is_supported( + compute_capability + ) + if is_base_supported and is_fallback_supported: + return True, None + if not is_base_supported and not is_fallback_supported: + return ( + False, + f"base is not supported due to {reason_1}; " + f"fallback is not supported due to {reason_2}", + ) + if not is_base_supported: + return False, f"base is not supported due to {reason_1}" + return False, f"fallback is not supported due to {reason_2}" + + @classmethod + def can_implement( + cls, config: "FP8ScaledMMLinearLayerConfig" + ) -> tuple[bool, 
str | None]: + can_implement_base, reason_1 = cls.base_type.can_implement(config) + can_implement_fallback, reason_2 = cls.fallback_type.can_implement(config) + if can_implement_base and can_implement_fallback: + return True, None + if not can_implement_base and not can_implement_fallback: + return ( + False, + f"base cannot implement due to {reason_1}; " + f"fallback cannot implement due to {reason_2}", + ) + if not can_implement_base: + return False, f"base cannot implement due to {reason_1}" + return False, f"fallback cannot implement due to {reason_2}" diff --git a/vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py index cdb69b06f5cd..b9f6f0c8f873 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py @@ -14,14 +14,11 @@ ) from vllm.platforms import current_platform - -@dataclass -class ScaledMMLinearLayerConfig: - pass +from ..base import MMLinearLayerConfig @dataclass -class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): +class Int8ScaledMMLinearLayerConfig(MMLinearLayerConfig): # TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig is_static_input_scheme: bool is_channelwise: bool @@ -29,10 +26,12 @@ class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): @dataclass -class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): +class FP8ScaledMMLinearLayerConfig(MMLinearLayerConfig): weight_quant_key: QuantKey activation_quant_key: QuantKey - out_dtype: torch.dtype | None + weight_shape: tuple[int, int] + input_dtype: torch.dtype + out_dtype: torch.dtype _FP8ParamsT = tuple[ @@ -50,7 +49,7 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): ] _ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT) -_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig) +_ConfigT = TypeVar("_ConfigT", bound=MMLinearLayerConfig) class 
ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC): diff --git a/vllm/model_executor/kernels/linear/scaled_mm/__init__.py b/vllm/model_executor/kernels/linear/scaled_mm/__init__.py index 2323a02ba593..e86684b2f8a1 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/__init__.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/__init__.py @@ -4,6 +4,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.aiter import ( AiterInt8ScaledMMLinearKernel, ) +from vllm.model_executor.kernels.linear.scaled_mm.BlockScaledMMLinearKernel import ( + Fp8BlockScaledMMLinearKernel, +) from vllm.model_executor.kernels.linear.scaled_mm.cpu import ( CPUInt8ScaledMMLinearKernel, ) @@ -31,7 +34,6 @@ Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, ScaledMMLinearKernel, - ScaledMMLinearLayerConfig, ) from vllm.model_executor.kernels.linear.scaled_mm.triton import ( TritonInt8ScaledMMLinearKernel, @@ -55,4 +57,5 @@ "RowWiseTorchFP8ScaledMMLinearKernel", "ROCmFP8ScaledMMLinearKernel", "TritonInt8ScaledMMLinearKernel", + "Fp8BlockScaledMMLinearKernel", ] diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index 1945a1e4354d..01d2298ed4a0 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -6,8 +6,15 @@ from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, +) from vllm.platforms import current_platform +from .BlockScaledMMLinearKernel import ( + Fp8BlockScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, +) from .cutlass import CutlassInt8ScaledMMLinearKernel from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig @@ -107,3 +114,54 @@ def apply_weights( # b to be [N, K] # CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, 
out_dtype) + + +class AiterFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel): + def __init__(self, config: FP8ScaledMMLinearLayerConfig): + super().__init__(config) + n, k = config.weight_shape + + self.use_triton = ( + not current_platform.is_fp8_fnuz() + and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k) + ) + + @classmethod + def is_supported(cls, compute_capability=None): + return ( + rocm_aiter_ops.is_linear_enabled(), + "Only supported on ROCm platform \ + with aiter package installed.", + ) + + @classmethod + def can_implement(cls, config: FP8ScaledMMLinearLayerConfig): + can_implement_base, reason = super().can_implement(config) + if not can_implement_base: + return can_implement_base, reason + + act_quant_desc = config.activation_quant_key.scale + if act_quant_desc.group_shape != GroupShape(1, 128): + return ( + False, + "Supports only dynamic per token group activation " + "quantization with group_shape=(1,128).", + ) + return True, None + + def apply_block_scaled_mm( + self, + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + ) -> torch.Tensor: + out_dtype = self.config.out_dtype + if self.use_triton: + gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale + else: + gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale + + return gemm_a8w8_blockscale_op( + A, B, As, Bs, list(self.weight_group_shape), output_dtype=out_dtype + ) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py index bcaf57bcbb26..618084029159 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py @@ -5,12 +5,19 @@ import torch from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.quant_utils import ( 
+ GroupShape, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + CUTLASS_BLOCK_FP8_SUPPORTED, convert_to_channelwise, ) from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op +from .BlockScaledMMLinearKernel import Fp8BlockScaledMMLinearKernel from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, @@ -171,3 +178,143 @@ def apply_scaled_mm( A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias ) return output.view(*output_shape) + + +class CutlassFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel): + def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None: + super().__init__(config) + act_scale_descriptor = config.activation_quant_key.scale + self.weight_group_shape = config.weight_quant_key.scale.group_shape + self.quant_fp8 = QuantFP8( + static=act_scale_descriptor.static, + group_shape=act_scale_descriptor.group_shape, + num_token_padding=self.get_output_padding(), + use_ue8m0=False, + column_major_scales=True, + ) + self.is_hopper = current_platform.is_device_capability(90) + + @classmethod + def is_supported(cls, compute_capability=None): + if not CUTLASS_BLOCK_FP8_SUPPORTED: + return ( + False, + "The device compute capability of" + f"{compute_capability} is not supported.", + ) + return True, None + + @classmethod + def can_implement(cls, config: FP8ScaledMMLinearLayerConfig): + can_implement_base, reason = super().can_implement(config) + if not can_implement_base: + return can_implement_base, reason + + act_quant_desc = config.activation_quant_key.scale + if act_quant_desc.group_shape != GroupShape(1, 128): + return ( + False, + "Supports only dynamic per token group activation " + "quantization with group_shape=(1,128).", + ) + return True, None + + def apply_block_scaled_mm( + self, + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + ) -> torch.Tensor: + out_dtype = self.config.out_dtype + if 
self.is_hopper: + return torch.ops.vllm.padded_cutlass( + A, + B, + As, + Bs, + list(self.weight_group_shape), + out_dtype, + ) + else: + return ops.cutlass_scaled_mm( + A, + B.T, + out_dtype=out_dtype, + scale_a=As, + scale_b=Bs.T, + ) + + +def cutlass_scaled_mm( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + return ops.cutlass_scaled_mm( + A, + B.T, + out_dtype=output_dtype, + scale_a=As, + scale_b=Bs.T, + ) + + +def _padded_cutlass( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + pad_multiple = 4 + dim = qx.shape[0] + padded = ( + dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple) + ) + + has_pad = padded > dim + + if has_pad: + padded_shape = [padded, *qx.shape[1:]] + padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype) + padded_qx[0 : qx.shape[0], ...].copy_(qx) + + padded_x_scale_shape = [*x_scale.shape[1:], padded] + padded_x_scale = torch.ones( + padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype + ).permute(-1, -2) + padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale) + + output = cutlass_scaled_mm( + padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype + ) + return output[0 : qx.shape[0], ...] 
+ else: + return cutlass_scaled_mm( + qx, weight, x_scale, weight_scale, block_size, output_dtype + ) + + +def _padded_cutlass_fake( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty( + (qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device + ) + + +direct_register_custom_op( + "padded_cutlass", + _padded_cutlass, + fake_impl=_padded_cutlass_fake, +) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/deep_gemm.py b/vllm/model_executor/kernels/linear/scaled_mm/deep_gemm.py new file mode 100644 index 000000000000..a369623a3b17 --- /dev/null +++ b/vllm/model_executor/kernels/linear/scaled_mm/deep_gemm.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import vllm.envs as envs +from vllm.config import get_current_vllm_config +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + deepgemm_post_process_fp8_weight_block, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, +) +from vllm.model_executor.utils import replace_parameter +from vllm.platforms import current_platform +from vllm.utils.deep_gemm import ( + fp8_gemm_nt, + is_deep_gemm_e8m0_used, + is_deep_gemm_supported, + should_auto_disable_deep_gemm, + should_use_deepgemm_for_fp8_linear, +) +from vllm.utils.torch_utils import direct_register_custom_op + +from .BlockScaledMMLinearKernel import ( + Fp8BlockScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, +) + + +class DeepGemmFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel): + def __init__(self, config: FP8ScaledMMLinearLayerConfig): + super().__init__(config) + self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() + act_scale_descriptor = config.activation_quant_key.scale 
+ self.is_deep_gemm_supported = is_deep_gemm_supported() + self.quant_fp8 = QuantFP8( + static=False, + group_shape=act_scale_descriptor.group_shape, + use_ue8m0=self.use_deep_gemm_e8m0, + tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES, + column_major_scales=True, + ) + + @classmethod + def is_supported(cls, compute_capability=None): + if not current_platform.is_cuda(): + return False, "DeepGEMM is only supported on cuda platform" + if not is_deep_gemm_supported(): + return False, "Currently, only Hopper and Blackwell GPUs are supported." + return True, None + + @classmethod + def can_implement(cls, config): + can_implement_base, reason = super().can_implement(config) + if not can_implement_base: + return can_implement_base, reason + if config.out_dtype != torch.bfloat16: + return (False, "Supports only output dtype of bfloat16") + + act_quant_desc = config.activation_quant_key.scale + if act_quant_desc.group_shape != GroupShape(1, 128): + return ( + False, + "Supports only dynamic per token group activation " + "quantization with group_shape=(1,128).", + ) + model_config = get_current_vllm_config().model_config + + if model_config is None: + return False, "Model configuration is required." + + model_type = getattr(model_config.hf_text_config, "model_type", None) + if should_auto_disable_deep_gemm(model_type): + return False, f"Should not use deepgemm for model {model_type}" + + if not should_use_deepgemm_for_fp8_linear( + config.out_dtype, config.weight_shape + ): + return False, "The provided metadata is not supported." 
+ return True, None + + def process_weights_after_loading(self, layer): + super().process_weights_after_loading(layer) + params = self._get_layer_params(layer) + assert layer.weight_block_size is not None + + if self.is_deep_gemm_supported: + weight_scale_invs = params.weight_scale_inv + scale_attr = ( + params.WEIGHT_SCALE_INV + if weight_scale_invs is not None + else params.WEIGHT_SCALE + ) + dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block( + wq=params.weight, + ws=weight_scale_invs + if weight_scale_invs is not None + else params.weight_scale, + quant_block_shape=tuple(layer.weight_block_size), + use_e8m0=self.use_deep_gemm_e8m0, + ) + replace_parameter(layer, params.WEIGHT, dg_weight) + replace_parameter(layer, scale_attr, dg_weight_scale) + + def apply_block_scaled_mm( + self, + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + ) -> torch.Tensor: + out_dtype = self.config.out_dtype + output = torch.empty( + (A.shape[0], B.shape[0]), + dtype=out_dtype, + device=A.device, + ) + torch.ops.vllm.fp8_gemm_nt_op(A, As, B, Bs, output, self.use_deep_gemm_e8m0) + return output + + +def _fp8_gemm_nt_op( + q_input: torch.Tensor, + input_scale: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + output: torch.Tensor, + use_deep_gemm_e8m0: bool, +) -> None: + fp8_gemm_nt( + (q_input, input_scale), + (weight, weight_scale), + output, + is_deep_gemm_e8m0_used=use_deep_gemm_e8m0, + ) + + +def _fp8_gemm_nt_op_fake( + q_input: torch.Tensor, + input_scale: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + output: torch.Tensor, + use_deep_gemm_e8m0: bool, +) -> None: + return None + + +direct_register_custom_op( + "fp8_gemm_nt_op", + _fp8_gemm_nt_op, + mutates_args=["output"], + fake_impl=_fp8_gemm_nt_op_fake, +) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py b/vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py index 991cda862acf..c84fd5dda84e 100644 --- 
a/vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py @@ -2,11 +2,32 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import ClassVar + import torch +import vllm.envs as envs +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, +) from vllm.platforms import current_platform -from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer +from vllm.utils.flashinfer import ( + flashinfer_fp8_blockscale_gemm, + flashinfer_scaled_fp8_mm, + has_flashinfer, + is_flashinfer_fp8_blockscale_gemm_supported, + should_use_flashinfer_for_blockscale_fp8_gemm, +) +from vllm.utils.torch_utils import direct_register_custom_op +from .BlockScaledMMLinearKernel import ( + Fp8BlockScaledDynamicMMLinearKernel, + Fp8BlockScaledMMLinearKernel, +) +from .deep_gemm import DeepGemmFp8BlockScaledMMKernel, fp8_gemm_nt from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, @@ -55,3 +76,256 @@ def apply_scaled_mm( return flashinfer_scaled_fp8_mm( A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias ) + + +class FlashInferFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel): + # FlashInfer accepts BF16 input and handles FP8 conversion internally. 
+ apply_input_quant: ClassVar[bool] = False + + def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None: + super().__init__(config) + + @classmethod + def can_implement(cls, config: FP8ScaledMMLinearLayerConfig): + can_implement_base, reason = super().can_implement(config) + if not can_implement_base: + return can_implement_base, reason + + act_quant_desc = config.activation_quant_key.scale + if act_quant_desc.group_shape != GroupShape(1, 128): + return ( + False, + "Supports only dynamic per token group activation " + "quantization with group_shape=(1,128).", + ) + + if not should_use_flashinfer_for_blockscale_fp8_gemm( + is_flashinfer_fp8_blockscale_gemm_supported(), + config.out_dtype, + config.input_dtype, + config.weight_quant_key.dtype, + config.weight_shape, + ): + return ( + False, + "The provided metadata is not supported.", + ) + + return True, None + + @classmethod + def is_supported(cls, compute_capability=None): + if not current_platform.is_cuda(): + return False, "only cuda devices are supported." + + if not is_flashinfer_fp8_blockscale_gemm_supported(): + return False, "FlashInfer block-scale FP8 GEMM is not available." + + return True, None + + def apply_block_scaled_mm( + self, + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + ) -> torch.Tensor: + # A is BF16 — FlashInfer handles FP8 conversion internally. + # As is a placeholder (apply_input_quant=False) and is not used here. + return torch.ops.vllm.flashinfer_fp8_blockscale_gemm( + A, # BF16 input + B, # FP8 weight + Bs, # Weight scales + ) + + +class FlashInferFp8DeepGEMMDynamicBlockScaledKernel( + Fp8BlockScaledDynamicMMLinearKernel +): + """ + Conditional FlashInfer / DeepGEMM FP8 block-scaled GEMM. + + Dispatches between two kernels based on input batch size: + - Small batches (M < 32): FlashInfer's swapAB trick for better utilisation. + - Large batches (M >= 32): DeepGEMM for peak throughput. 
+
+    apply_input_quant is False because FlashInfer accepts BF16 input and
+    handles FP8 conversion internally. The DeepGEMM branch therefore
+    quantises BF16→FP8 inside apply_mm via a closure before dispatching to
+    the DeepGEMM kernel — keeping both branches compatible with the single
+    BF16 tensor operand list passed by torch.cond.
+    """
+
+    base_type: ClassVar[type[FlashInferFp8BlockScaledMMKernel]] = (
+        FlashInferFp8BlockScaledMMKernel
+    )
+    fallback_type: ClassVar[type[DeepGemmFp8BlockScaledMMKernel]] = (
+        DeepGemmFp8BlockScaledMMKernel
+    )
+    apply_input_quant: ClassVar[bool] = False
+
+    def __init__(self, config: FP8ScaledMMLinearLayerConfig):
+        super().__init__(config)
+        self.base: FlashInferFp8BlockScaledMMKernel
+        self.fallback: DeepGemmFp8BlockScaledMMKernel
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        # DeepGEMM needs post-processing; both kernels share the same
+        # parameter tensor layout so processing once is sufficient.
+        self.fallback.process_weights_after_loading(layer)
+
+    def apply_block_scaled_mm(
+        self,
+        A: torch.Tensor,
+        B: torch.Tensor,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+    ) -> torch.Tensor:
+        group_size = self.weight_group_shape.col
+        use_deep_gemm_e8m0 = self.fallback.use_deep_gemm_e8m0
+
+        return torch.ops.vllm.dynamic_flashinfer_deepgemm_blockscale_gemm(
+            A, B, Bs, group_size, use_deep_gemm_e8m0
+        )
+
+
+def _flashinfer_fp8_blockscale_gemm_impl(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+) -> torch.Tensor:
+    return flashinfer_fp8_blockscale_gemm(
+        input=input,
+        weight=weight,
+        weight_scale=weight_scale,
+        out_dtype=torch.bfloat16,
+    )
+
+
+def _flashinfer_fp8_blockscale_gemm_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Required fake/meta implementation for torch.compile graph tracing. 
+ """ + return torch.empty( + input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device + ) + + +direct_register_custom_op( + "flashinfer_fp8_blockscale_gemm", + _flashinfer_fp8_blockscale_gemm_impl, + fake_impl=_flashinfer_fp8_blockscale_gemm_fake, +) + + +def _dynamic_flashinfer_deepgemm_blockscale_gemm_impl( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + group_size: int, + use_deep_gemm_e8m0: bool, +) -> torch.Tensor: + """ + Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection. + + This function switches between two optimized kernels based on the input batch size: + - For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization. + - For larger batches (M >= 32): Uses the official DeepGEMM kernel. + + The conditional logic must use torch.cond() instead of a simple if-else statement + to maintain compatibility with torch.compile graph compilation. + + This batch-size-dependent selection is essential for maintaining model accuracy. + Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1 + when using FlashInfer's DeepGEMM on M>=32. The M < 32 strategy fixes the accuracy + drop. 
+
+    Args:
+        input: Input tensor of shape (batch_size, input_dim) in BF16 format
+        weight: Weight tensor of shape (output_dim, input_dim) in FP8 format
+        weight_scale: Scale factors for weight quantization (per-group)
+        group_size: Quantization group size for the weight tensor
+        use_deep_gemm_e8m0: Whether to use the E8M0 format in DeepGEMM quantization
+
+    Returns:
+        Output tensor of shape (batch_size, output_dim) in bfloat16 format
+    """
+
+    def run_flashinfer_deepgemm_swapAB(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ) -> torch.Tensor:
+        return flashinfer_fp8_blockscale_gemm(
+            input=input,
+            weight=weight,
+            weight_scale=weight_scale,
+            out_dtype=torch.bfloat16,
+        )
+
+    def run_deepgemm(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ) -> torch.Tensor:
+        q_input, input_scale = per_token_group_quant_fp8(
+            input,
+            group_size=group_size,
+            column_major_scales=True,
+            use_ue8m0=use_deep_gemm_e8m0,
+        )
+        output = torch.empty(
+            (q_input.shape[0], weight.shape[0]),
+            dtype=torch.bfloat16,
+            device=q_input.device,
+        )
+        fp8_gemm_nt(
+            (q_input, input_scale),
+            (weight, weight_scale),
+            output,
+            is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
+        )
+        return output
+
+    if envs.VLLM_BATCH_INVARIANT:
+        return run_deepgemm(input, weight, weight_scale)
+
+    condition = input.shape[0] < 32
+
+    # PyTorch's torch.compile cannot handle input-dependent control flow in standard
+    # Python conditionals. torch.cond() explicitly registers both code paths in the
+    # computation graph, allowing torch.compile to capture both branches. 
+ # without torch.cond, the M < 32 condition won't be able to be captured by torch + # compile + return torch.cond( + condition, + run_flashinfer_deepgemm_swapAB, + run_deepgemm, + (input, weight, weight_scale), + ) + + +def _dynamic_flashinfer_deepgemm_blockscale_gemm_fake( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + group_size: int, + use_deep_gemm_e8m0: bool, +) -> torch.Tensor: + """ + Required fake/meta implementation for torch.compile graph tracing. + """ + return torch.empty( + input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device + ) + + +direct_register_custom_op( + "dynamic_flashinfer_deepgemm_blockscale_gemm", + _dynamic_flashinfer_deepgemm_blockscale_gemm_impl, + fake_impl=_dynamic_flashinfer_deepgemm_blockscale_gemm_fake, +) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/triton.py b/vllm/model_executor/kernels/linear/scaled_mm/triton.py index c68638a6ad96..bb4392a6de86 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/triton.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/triton.py @@ -13,7 +13,11 @@ convert_to_channelwise, ) from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op +from .BlockScaledMMLinearKernel import ( + Fp8BlockScaledMMLinearKernel, +) from .cutlass import CutlassInt8ScaledMMLinearKernel from .ScaledMMLinearKernel import ( Int8ScaledMMLinearLayerConfig, @@ -150,3 +154,67 @@ def apply_weights( out -= (x_s * w_s_row * azp_adj).to(x.dtype) return out + + +class TritonFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel): + @classmethod + def is_supported(cls, compute_capability=None): + if not current_platform.is_cuda_alike(): + return False, "only cuda like devices are supported." 
+ return True, None + + def apply_block_scaled_mm( + self, + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( + A, + B, + As, + Bs, + list(self.weight_group_shape), + self.config.out_dtype, + ) + + +# TODO we should be able to change the type of block_size to GroupShape +# after we resolve GroupShape compilation issue +# https://github.com/vllm-project/vllm/issues/25270 +def _w8a8_triton_block_scaled_mm_func( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + w8a8_triton_block_scaled_mm, + ) + + return w8a8_triton_block_scaled_mm( + qx, weight, x_scale, weight_scale, block_size, output_dtype + ) + + +def _w8a8_triton_block_scaled_mm_fake( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty( + (qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device + ) + + +direct_register_custom_op( + "w8a8_triton_block_scaled_mm_func", + _w8a8_triton_block_scaled_mm_func, + fake_impl=_w8a8_triton_block_scaled_mm_fake, +) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 9c4914e68778..c6b810eb9679 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -8,6 +8,7 @@ from torch.nn import Parameter from vllm._aiter_ops import rocm_aiter_ops +from vllm.config import get_current_vllm_config from vllm.logger import 
init_logger from vllm.model_executor.kernels.linear import ( init_fp8_linear_kernel, @@ -16,18 +17,16 @@ CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, - maybe_post_process_fp8_weight_block, - process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy, process_fp8_weight_tensor_strategy, validate_fp8_block_shape, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, + create_fp8_quant_key, kFp8DynamicTokenSym, kFp8StaticTensorSym, kFp8StaticTokenSym, @@ -67,6 +66,7 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) self.weight_quant = weight_quant self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() + self.input_dtype = get_current_vllm_config().model_config.dtype self.is_static_input_scheme = is_static_input_scheme self.weight_block_size = self.weight_quant.block_structure @@ -75,21 +75,18 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled() assert not self.is_static_input_scheme self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) - self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(*self.weight_block_size), - act_quant_group_shape=self.act_q_group_shape, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported, + + self.weight_quant_key = create_fp8_quant_key( + static=True, group_shape=GroupShape(*self.weight_block_size) ) - else: - activation_quant_key = activation_quant_key_mapping[is_static_input_scheme] - weight_quant_key = weight_quant_key_mapping[self.strategy] - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=activation_quant_key, - weight_quant_key=weight_quant_key, - 
out_dtype=self.out_dtype, - module_name=self.__class__.__name__, + self.activation_quant_key = create_fp8_quant_key( + static=False, group_shape=self.act_q_group_shape ) + else: + self.activation_quant_key = activation_quant_key_mapping[ + is_static_input_scheme + ] + self.weight_quant_key = weight_quant_key_mapping[self.strategy] @classmethod def get_min_capability(cls) -> int: @@ -146,6 +143,15 @@ def create_weights( input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader) layer.register_parameter("input_scale", input_scale) + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + def process_weights_after_loading(self, layer) -> None: if self.strategy == QuantizationStrategy.TENSOR: weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy( @@ -163,10 +169,12 @@ def process_weights_after_loading(self, layer) -> None: elif self.strategy == QuantizationStrategy.BLOCK: assert self.is_static_input_scheme is False - weight, weight_scale = process_fp8_weight_block_strategy( - layer.weight, layer.weight_scale - ) - input_scale = None + self.fp8_linear.process_weights_after_loading(layer) + + layer.input_scale = None + # fp8_linear.process_weights_after_loading applies the post process + # and reassigns the weight and weight_scale buffers to layer attributes. 
+ return else: raise ValueError( @@ -185,8 +193,6 @@ def process_weights_after_loading(self, layer) -> None: layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) else: layer.input_scale = None - if self.strategy == QuantizationStrategy.BLOCK: - maybe_post_process_fp8_weight_block(layer) if hasattr(self, "fp8_linear"): self.fp8_linear.process_weights_after_loading(layer) @@ -197,13 +203,4 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - if self.weight_block_size is not None: - return self.w8a8_block_fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, - bias=bias, - ) - return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index c952b7690846..377cec364bfc 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -7,6 +7,7 @@ from torch.nn import Module from torch.nn.parameter import Parameter +from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.kernels.linear import ( init_fp8_linear_kernel, @@ -93,12 +94,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config self.out_dtype = torch.get_default_dtype() - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=kFp8DynamicTokenSym, - weight_quant_key=kFp8StaticTokenSym, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) + self.input_dtype = get_current_vllm_config().model_config.dtype def create_weights( self, @@ -149,6 +145,15 @@ def create_weights( ) layer.input_scale_ub = input_scale_ub + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8StaticTokenSym, + 
weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + def process_weights_after_loading(self, layer: Module) -> None: # required by torch.compile layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 2816f7656504..dfb09d57361e 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -10,7 +10,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops -from vllm._aiter_ops import rocm_aiter_ops +from vllm.config import get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.kernels.linear import ( @@ -45,13 +45,10 @@ ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, - maybe_post_process_fp8_weight_block, process_fp8_input_tensor_strategy_moe, - process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy, process_fp8_weight_tensor_strategy_moe, validate_fp8_block_shape, @@ -61,6 +58,7 @@ ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, + create_fp8_quant_key, is_layer_skipped, kFp8Dynamic128Sym, kFp8DynamicTensorSym, @@ -273,12 +271,13 @@ def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.out_dtype = torch.get_default_dtype() + self.input_dtype = get_current_vllm_config().model_config.dtype # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast 
weight-only FP8 quantization self.marlin_input_dtype = None + self.use_marlin = False - self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled() if self.quant_config.use_deep_gemm is not None: self.use_deep_gemm = self.quant_config.use_deep_gemm else: @@ -288,37 +287,26 @@ def __init__(self, quant_config: Fp8Config): self.block_quant = self.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" - # Use per-token quantization for better perf if dynamic and cutlass - if self.act_q_static: - activation_quant_key = kFp8StaticTensorSym - elif cutlass_fp8_supported(): - activation_quant_key = kFp8DynamicTokenSym - else: - activation_quant_key = kFp8DynamicTensorSym - if self.block_quant: - weight_quant_key = kFp8Static128BlockSym - else: - weight_quant_key = kFp8StaticTensorSym - - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=activation_quant_key, - weight_quant_key=weight_quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) - self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel) - - if self.block_quant and not self.use_marlin: assert not self.act_q_static assert self.weight_block_size is not None - self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(*self.weight_block_size), - act_quant_group_shape=GroupShape(1, self.weight_block_size[0]), - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported, - use_deep_gemm=self.use_deep_gemm, + + self.activation_quant_key = create_fp8_quant_key( + static=self.act_q_static, + group_shape=GroupShape(1, self.weight_block_size[0]), + ) + self.weight_quant_key = create_fp8_quant_key( + static=True, group_shape=GroupShape(*self.weight_block_size) ) + else: + self.weight_quant_key = kFp8StaticTensorSym + # Use per-token quantization for better perf if dynamic and cutlass + if self.act_q_static: + 
self.activation_quant_key = kFp8StaticTensorSym + elif cutlass_fp8_supported(): + self.activation_quant_key = kFp8DynamicTokenSym + else: + self.activation_quant_key = kFp8DynamicTensorSym def create_weights( self, @@ -384,6 +372,17 @@ def create_weights( set_weight_attrs(scale, {"scale_type": "input_scale"}) layer.register_parameter("input_scale", scale) + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + + self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel) + def process_weights_after_loading(self, layer: Module) -> None: if self.use_marlin: # Only Marlin kernels support `marlin_input_dtype`; guard to avoid @@ -398,13 +397,7 @@ def process_weights_after_loading(self, layer: Module) -> None: if self.block_quant: assert not self.act_q_static - weight, weight_scale_inv = process_fp8_weight_block_strategy( - layer.weight, layer.weight_scale_inv - ) - - # Update layer with new values - replace_parameter(layer, "weight", weight.data) - replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data) + self.fp8_linear.process_weights_after_loading(layer) # If checkpoint not serialized fp8, quantize the weights. 
else: @@ -435,9 +428,6 @@ def process_weights_after_loading(self, layer: Module) -> None: else: layer.input_scale = None - if self.block_quant and self.use_deep_gemm: - maybe_post_process_fp8_weight_block(layer) - def apply( self, layer: torch.nn.Module, @@ -449,12 +439,10 @@ def apply( if envs.VLLM_BATCH_INVARIANT: if self.block_quant: assert self.weight_block_size is not None - return self.w8a8_block_fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale_inv, - input_scale=layer.input_scale, - bias=bias, + return self.fp8_linear.apply_weights( + layer, + x, + bias, ) else: # per-tensor/channel: dequant to BF16 and run GEMM @@ -483,17 +471,6 @@ def apply( if self.use_marlin: return self.fp8_linear.apply_weights(layer, x, bias) - if self.block_quant: - assert self.weight_block_size is not None - - return self.w8a8_block_fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale_inv, - input_scale=layer.input_scale, - bias=bias, - ) - return self.fp8_linear.apply_weights(layer, x, bias) @@ -538,6 +515,24 @@ def create_weights( initialize_online_processing(layer) + # TODO: remove this check once the following RFC is resolved. + # https://github.com/vllm-project/vllm/issues/33314 + # This check is required because Mxfp8OnlineLinearMethod inherits from + # Fp8OnlineLinearMethod but only calls super().create_weights(), so we must + # skip the fp8_linear kernel creation. 
+ if hasattr(self, "mxfp8_linear"): + return + + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel) + def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 53b15950d922..ad188c665e95 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.kernels.linear import init_fp8_linear_kernel from vllm.model_executor.layers.attention import Attention, MLAAttention @@ -56,7 +57,6 @@ swap_w13_to_w31, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp, process_fp8_input_tensor_strategy_moe, process_fp8_weight_tensor_strategy_moe, ) @@ -78,6 +78,7 @@ ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, + create_fp8_quant_key, is_layer_skipped, kFp8DynamicTokenSym, kFp8StaticTensorSym, @@ -86,7 +87,6 @@ kNvfp4Static, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_block_fp8_supported, requantize_with_max_scale, ) from vllm.model_executor.parameter import ( @@ -450,12 +450,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config - self.fp8_linear = init_fp8_linear_kernel( - 
activation_quant_key=kFp8StaticTensorSym, - weight_quant_key=kFp8StaticTensorSym, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) + self.out_dtype = torch.get_default_dtype() + self.input_dtype = get_current_vllm_config().model_config.dtype def create_weights( self, @@ -505,6 +501,15 @@ def create_weights( scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", scale) + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8StaticTensorSym, + weight_quant_key=kFp8StaticTensorSym, + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight = layer.weight max_w_scale = layer.weight_scale.max() @@ -536,12 +541,8 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=kFp8DynamicTokenSym, - weight_quant_key=kFp8StaticTokenSym, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) + self.out_dtype = torch.get_default_dtype() + self.input_dtype = get_current_vllm_config().model_config.dtype def create_weights( self, @@ -587,6 +588,15 @@ def create_weights( weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8StaticTokenSym, + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = Parameter(layer.weight.t(), requires_grad=False) layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) @@ -616,12 +626,16 @@ def 
__init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config block_n, block_k = self._WEIGHT_BLOCK_SIZE self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE) - self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(block_n, block_k), - act_quant_group_shape=GroupShape(1, block_k), - cutlass_block_fp8_supported=cutlass_block_fp8_supported(), - use_aiter_and_is_supported=False, + + self.activation_quant_key = create_fp8_quant_key( + static=False, group_shape=GroupShape(1, block_k) ) + self.weight_quant_key = create_fp8_quant_key( + static=True, group_shape=GroupShape(block_n, block_k) + ) + + self.out_dtype = torch.get_default_dtype() + self.input_dtype = get_current_vllm_config().model_config.dtype def create_weights( self, @@ -688,8 +702,17 @@ def create_weights( weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) + self.w8a8_block_fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp. + # Keep weight in [out, in] layout for Fp8BlockScaledMMLinearKernel. 
layer.weight = Parameter(layer.weight.data, requires_grad=False) scale = layer.weight_scale @@ -713,13 +736,7 @@ def apply( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.w8a8_block_fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=None, - bias=bias, - ) + return self.w8a8_block_fp8_linear.apply_weights(layer, x, bias) class ModelOptFp8MoEMethod(FusedMoEMethodBase): diff --git a/vllm/model_executor/layers/quantization/online/fp8.py b/vllm/model_executor/layers/quantization/online/fp8.py index 941ae25b173d..fa8cf240627b 100644 --- a/vllm/model_executor/layers/quantization/online/fp8.py +++ b/vllm/model_executor/layers/quantization/online/fp8.py @@ -17,7 +17,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm._aiter_ops import rocm_aiter_ops +from vllm.config import get_current_vllm_config from vllm.model_executor.kernels.linear import init_fp8_linear_kernel from vllm.model_executor.layers.fused_moe import ( FusedMoEMethodBase, @@ -28,13 +28,9 @@ from vllm.model_executor.layers.linear import ( LinearMethodBase, ) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp, - maybe_post_process_fp8_weight_block, - process_fp8_weight_block_strategy, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, + create_fp8_quant_key, kFp8Dynamic128Sym, kFp8DynamicTensorSym, kFp8DynamicTokenSym, @@ -42,7 +38,6 @@ kFp8StaticTensorSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_block_fp8_supported, cutlass_fp8_supported, ) from vllm.model_executor.model_loader.reload.layerwise import ( @@ -51,7 +46,7 @@ from vllm.model_executor.parameter import ModelWeightParameter from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import current_platform -from vllm.utils.deep_gemm import is_deep_gemm_supported, per_block_cast_to_fp8 
+from vllm.utils.deep_gemm import per_block_cast_to_fp8 # --------------------------------------------------------------------------- # Online FP8 Linear Methods @@ -64,6 +59,10 @@ class _Fp8OnlineLinearBase(LinearMethodBase): uses_meta_device: bool = True + def __init__(self): + self.out_dtype = torch.get_default_dtype() + self.input_dtype = get_current_vllm_config().model_config.dtype + def create_weights( self, layer: torch.nn.Module, @@ -103,18 +102,41 @@ class Fp8PerTensorOnlineLinearMethod(_Fp8OnlineLinearBase): Loads fp16/bf16 weights and quantizes them per-tensor during loading.""" def __init__(self): - self.out_dtype = torch.get_default_dtype() + super().__init__() + self.weight_quant_key = kFp8StaticTensorSym # Use per-token quantization for better perf if dynamic and cutlass if cutlass_fp8_supported(): - activation_quant_key = kFp8DynamicTokenSym + self.activation_quant_key = kFp8DynamicTokenSym else: - activation_quant_key = kFp8DynamicTensorSym + self.activation_quant_key = kFp8DynamicTensorSym + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + super().create_weights( + layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype, + **extra_weight_attrs, + ) self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=activation_quant_key, - weight_quant_key=kFp8StaticTensorSym, - out_dtype=torch.get_default_dtype(), + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, module_name=self.__class__.__name__, ) @@ -166,19 +188,14 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase): Loads fp16/bf16 weights and quantizes them per-block during loading.""" def __init__(self): - self.out_dtype = 
torch.get_default_dtype() + super().__init__() self.weight_block_size = [128, 128] - - self.use_deep_gemm = is_deep_gemm_supported() - self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled() - self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() - - self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(*self.weight_block_size), - act_quant_group_shape=GroupShape(1, self.weight_block_size[0]), - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported, - use_deep_gemm=self.use_deep_gemm, + self.activation_quant_key = create_fp8_quant_key( + static=False, + group_shape=GroupShape(1, self.weight_block_size[0]), + ) + self.weight_quant_key = create_fp8_quant_key( + static=True, group_shape=GroupShape(*self.weight_block_size) ) def create_weights( @@ -202,6 +219,15 @@ def create_weights( ) layer.weight_block_size = self.weight_block_size + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return @@ -213,14 +239,10 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight, block_size=block_size, use_ue8m0=False ) - qweight, weight_scale_inv = process_fp8_weight_block_strategy( - qweight, weight_scale_inv - ) - replace_parameter(layer, "weight", qweight.data) replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data) - maybe_post_process_fp8_weight_block(layer) + self.fp8_linear.process_weights_after_loading(layer) # Prevent duplicate processing (e.g., during weight reload) layer._already_called_process_weights_after_loading = True @@ -234,12 +256,10 @@ def apply( 
assert self.weight_block_size is not None # Note: batch invariance already handled in the function below - return self.w8a8_block_fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale_inv, - input_scale=layer.input_scale, - bias=bias, + return self.fp8_linear.apply_weights( + layer, + x, + bias, ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 72f050a1245b..3312e6901d6d 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -7,6 +7,7 @@ import torch from torch.nn import Parameter +from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.kernels.linear import ( init_fp8_linear_kernel, @@ -57,6 +58,7 @@ def __init__( kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym ) self.out_dtype = torch.get_default_dtype() + self.input_dtype = get_current_vllm_config().model_config.dtype @classmethod def get_min_capability(cls) -> int: @@ -175,7 +177,9 @@ def create_weights( self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=self.activation_quant_key, weight_quant_key=self.weight_quant_key, - out_dtype=torch.get_default_dtype(), + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, module_name=self.__class__.__name__, ) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9568d1320bc6..19fdb1ec884d 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -12,15 +12,11 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger -from 
vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, get_fp8_min_max, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_BLOCK_FP8_SUPPORTED, all_close_1d, per_tensor_dequantize, ) @@ -29,22 +25,14 @@ ChannelQuantScaleParameter, PerTensorScaleParameter, ) -from vllm.model_executor.utils import replace_parameter, set_weight_attrs +from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import ( - fp8_gemm_nt, get_tma_aligned_size, is_deep_gemm_e8m0_used, - is_deep_gemm_supported, - should_use_deepgemm_for_fp8_linear, transform_sf_into_required_layout, ) -from vllm.utils.flashinfer import ( - flashinfer_fp8_blockscale_gemm, - is_flashinfer_fp8_blockscale_gemm_supported, - should_use_flashinfer_for_blockscale_fp8_gemm, -) from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -56,153 +44,6 @@ def is_fp8(x: torch.dtype | torch.Tensor) -> bool: return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz -# We need to pass in the is_hopper flag as argument because the function -# current_platform.is_device_capability() is not supported by Torch compiler. 
-def cutlass_scaled_mm( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - return ops.cutlass_scaled_mm( - A, - B.T, - out_dtype=output_dtype, - scale_a=As, - scale_b=Bs.T, - ) - - -# TODO we should be able to change the type of block_size to GroupShape -# after we resolve GroupShape compilation issue -# https://github.com/vllm-project/vllm/issues/25270 -def _w8a8_triton_block_scaled_mm_func( - qx: torch.Tensor, - weight: torch.Tensor, - x_scale: torch.Tensor, - weight_scale: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype, -) -> torch.Tensor: - return w8a8_triton_block_scaled_mm( - qx, weight, x_scale, weight_scale, block_size, output_dtype - ) - - -def _w8a8_triton_block_scaled_mm_fake( - qx: torch.Tensor, - weight: torch.Tensor, - x_scale: torch.Tensor, - weight_scale: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype, -) -> torch.Tensor: - return torch.empty( - (qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device - ) - - -direct_register_custom_op( - "w8a8_triton_block_scaled_mm_func", - _w8a8_triton_block_scaled_mm_func, - fake_impl=_w8a8_triton_block_scaled_mm_fake, -) - - -def _padded_cutlass( - qx: torch.Tensor, - weight: torch.Tensor, - x_scale: torch.Tensor, - weight_scale: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype, -) -> torch.Tensor: - pad_multiple = 4 - dim = qx.shape[0] - padded = ( - dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple) - ) - - has_pad = padded > dim - - if has_pad: - padded_shape = [padded, *qx.shape[1:]] - padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype) - padded_qx[0 : qx.shape[0], ...].copy_(qx) - - padded_x_scale_shape = [*x_scale.shape[1:], padded] - padded_x_scale = torch.ones( - padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype - ).permute(-1, -2) - padded_x_scale[0 : 
x_scale.shape[0], ...].copy_(x_scale) - - output = cutlass_scaled_mm( - padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype - ) - return output[0 : qx.shape[0], ...] - else: - return cutlass_scaled_mm( - qx, weight, x_scale, weight_scale, block_size, output_dtype - ) - - -def _padded_cutlass_fake( - qx: torch.Tensor, - weight: torch.Tensor, - x_scale: torch.Tensor, - weight_scale: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype, -) -> torch.Tensor: - return torch.empty( - (qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device - ) - - -direct_register_custom_op( - "padded_cutlass", - _padded_cutlass, - fake_impl=_padded_cutlass_fake, -) - - -def _fp8_gemm_nt_op( - q_input: torch.Tensor, - input_scale: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - output: torch.Tensor, - use_deep_gemm_e8m0: bool, -) -> None: - fp8_gemm_nt( - (q_input, input_scale), - (weight, weight_scale), - output, - is_deep_gemm_e8m0_used=use_deep_gemm_e8m0, - ) - - -def _fp8_gemm_nt_op_fake( - q_input: torch.Tensor, - input_scale: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - output: torch.Tensor, - use_deep_gemm_e8m0: bool, -) -> None: - return None - - -direct_register_custom_op( - "fp8_gemm_nt_op", - _fp8_gemm_nt_op, - mutates_args=["output"], - fake_impl=_fp8_gemm_nt_op_fake, -) - - def _triton_per_token_group_quant_fp8_impl( x: torch.Tensor, group_size: int, @@ -236,362 +77,6 @@ def _triton_per_token_group_quant_fp8_fake( ) -def _flashinfer_fp8_blockscale_gemm_impl( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - group_size: int, - use_deep_gemm_e8m0: bool, -) -> torch.Tensor: - """ - Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection. - - This function switches between two optimized kernels based on the input batch size: - - For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization. 
- - For larger batches (M >= 32): Uses the official DeepGEMM kernel. - - The conditional logic must use torch.cond() instead of a simple if-else statement - to maintain compatibility with torch.compile graph compilation. - - This batch-size-dependent selection is essential for maintaining model accuracy. - Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1 - when using FlashInfer's DeepGEMM on M>=32. The M < 32 strategy fixes the accuracy - drop. - - Args: - input: Input tensor of shape (batch_size, input_dim) in FP8 format - weight: Weight tensor of shape (output_dim, input_dim) in FP8 format - weight_scale: Scale factors for weight quantization (per-group) - group_size: Quantization group size for the weight tensor - use_deep_gemm_e8m0: Whether to use the E8M0 format in DeepGEMM quantization - - Returns: - Output tensor of shape (batch_size, output_dim) in bfloat16 format - """ - - def run_flashinfer_deepgemm_swapAB( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - ) -> torch.Tensor: - return flashinfer_fp8_blockscale_gemm( - input=input, - weight=weight, - weight_scale=weight_scale, - out_dtype=torch.bfloat16, - ) - - def run_deepgemm( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - ) -> torch.Tensor: - q_input, input_scale = per_token_group_quant_fp8( - input, - group_size=group_size, - column_major_scales=True, - use_ue8m0=use_deep_gemm_e8m0, - ) - output = torch.empty( - (q_input.shape[0], weight.shape[0]), - dtype=torch.bfloat16, - device=q_input.device, - ) - fp8_gemm_nt( - (q_input, input_scale), - (weight, weight_scale), - output, - is_deep_gemm_e8m0_used=use_deep_gemm_e8m0, - ) - return output - - if envs.VLLM_BATCH_INVARIANT: - return run_deepgemm(input, weight, weight_scale) - - condition = input.shape[0] < 32 - - # PyTorch's torch.compile cannot handle input-dependent control flow in standard - # Python conditionals. 
torch.cond() explicitly registers both code paths in the - # computation graph, allowing torch.compile to capture both branches. - # without torch.cond, the M < 32 condition won't be able to be captured by torch - # compile - return torch.cond( - condition, - run_flashinfer_deepgemm_swapAB, - run_deepgemm, - (input, weight, weight_scale), - ) - - -def _flashinfer_fp8_blockscale_gemm_fake( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - group_size: int, - use_deep_gemm_e8m0: bool, -) -> torch.Tensor: - """ - Required fake/meta implementation for torch.compile graph tracing. - """ - return torch.empty( - input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device - ) - - -direct_register_custom_op( - "flashinfer_fp8_blockscale_gemm", - _flashinfer_fp8_blockscale_gemm_impl, - fake_impl=_flashinfer_fp8_blockscale_gemm_fake, -) - - -# TODO fix ROCm->Triton custom path: -# https://github.com/vllm-project/vllm/issues/14397 -class W8A8BlockFp8LinearOp: - """ - This class executes a Blocked FP8 linear layer using cutlass if supported - and torch.scaled_mm otherwise. - """ - - def __init__( - self, - weight_group_shape: GroupShape, - act_quant_group_shape: GroupShape, - cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, - use_deep_gemm: bool | None = None, - ): - self.weight_group_shape = weight_group_shape - self.act_quant_group_shape = act_quant_group_shape - if use_deep_gemm is not None: - self.is_deep_gemm_supported = use_deep_gemm - else: - self.is_deep_gemm_supported = is_deep_gemm_supported() - self.is_hopper = current_platform.is_device_capability(90) - self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() - self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported() - - # Get the correct blockscale mul and input quant operations. 
- # We can't use _dispatch_w8a8_blockscale_op to figure out if we want - # to use deepgemm because we don't know the shape of weights (and - # whether deepgemm supports it) at the init time. - self.w8a8_blockscale_op, self.input_quant_op = ( - self._dispatch_w8a8_blockscale_op( - cutlass_block_fp8_supported, use_aiter_and_is_supported - ) - ) - self.deepgemm_input_quant_op = ( - QuantFP8( - False, - self.act_quant_group_shape, - column_major_scales=True, - tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES, - use_ue8m0=self.use_deep_gemm_e8m0, - ) - if self.is_deep_gemm_supported - else None - ) - - def apply( - self, - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - input_scale: torch.Tensor | None = None, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - assert input_scale is None - # View input as 2D matrix for fp8 methods - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[0]] - output_dtype = input.dtype - - if should_use_flashinfer_for_blockscale_fp8_gemm( - self.is_flashinfer_supported, output_dtype, input_2d, weight - ) and should_use_deepgemm_for_fp8_linear( - output_dtype, weight, self.is_deep_gemm_supported - ): - output = self._run_flashinfer(input_2d, weight, weight_scale) - - elif should_use_deepgemm_for_fp8_linear( - output_dtype, weight, self.is_deep_gemm_supported - ): - output = self._run_deepgemm(input_2d, weight, weight_scale) - else: - output = self.w8a8_blockscale_op( - input_2d, weight, weight_scale, input_scale - ) - - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) - - def _run_deepgemm( - self, - input_2d: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - ) -> torch.Tensor: - assert self.deepgemm_input_quant_op is not None - q_input, input_scale = self.deepgemm_input_quant_op(input_2d) - output = torch.empty( - (q_input.shape[0], weight.shape[0]), - dtype=torch.bfloat16, - 
device=q_input.device, - ) - torch.ops.vllm.fp8_gemm_nt_op( - q_input, input_scale, weight, weight_scale, output, self.use_deep_gemm_e8m0 - ) - return output - - def _run_cutlass( - self, - input_2d: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - input_scale: torch.Tensor | None = None, - ) -> torch.Tensor: - assert input_scale is None - assert self.input_quant_op is not None - q_input, input_scale = self.input_quant_op(input_2d) - if self.is_hopper: - return torch.ops.vllm.padded_cutlass( - q_input, - weight, - input_scale, - weight_scale, - list(self.weight_group_shape), - input_2d.dtype, - ) - else: - return cutlass_scaled_mm( - q_input, - weight, - input_scale, - weight_scale, - list(self.weight_group_shape), - input_2d.dtype, - ) - - def _run_aiter( - self, - input_2d: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - input_scale: torch.Tensor | None = None, - ) -> torch.Tensor: - assert self.act_quant_group_shape == GroupShape(1, 128) - - n, k = weight.shape - - use_triton = ( - not current_platform.is_fp8_fnuz() - and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k) - ) - - if use_triton: - gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale - else: - gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale - - if input_scale is not None: - q_input = input_2d - else: - q_input, input_scale = self.input_quant_op(input_2d, use_triton=use_triton) - - return gemm_a8w8_blockscale_op( - q_input, - weight, - input_scale, - weight_scale, - list(self.weight_group_shape), - output_dtype=input_2d.dtype, - ) - - def _run_triton( - self, - input_2d: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - input_scale: torch.Tensor | None = None, - ) -> torch.Tensor: - assert input_scale is None - assert self.input_quant_op is not None - q_input, input_scale = self.input_quant_op(input_2d) - return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( - q_input, - weight, - input_scale, - weight_scale, 
- list(self.weight_group_shape), - input_2d.dtype, - ) - - def _run_flashinfer( - self, - input_2d: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - ) -> torch.Tensor: - """ - Run FlashInfer FP8 block-scale GEMM. - - This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels - and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper). - """ - # Now call FlashInfer with BF16 input + FP8 weight, input will be - # quantized with FlashInfer kernel (W8A8) - output = torch.ops.vllm.flashinfer_fp8_blockscale_gemm( - input=input_2d, # BF16 input - weight=weight, # FP8 weight - weight_scale=weight_scale, # Weight scales - group_size=self.act_quant_group_shape.col, - use_deep_gemm_e8m0=self.use_deep_gemm_e8m0, - ) - return output - - def _dispatch_w8a8_blockscale_op( - self, - use_cutlass: bool, - use_aiter_and_is_supported: bool, - ) -> tuple[ - Callable[ - [ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor | None, - ], - torch.Tensor, - ], - QuantFP8, - ]: - if use_cutlass: - return self._run_cutlass, ( - QuantFP8( - False, - self.act_quant_group_shape, - column_major_scales=True, - use_ue8m0=False, - ) - ) - if use_aiter_and_is_supported: - return self._run_aiter, QuantFP8( - False, - self.act_quant_group_shape, - column_major_scales=False, - use_ue8m0=False, - ) - return self._run_triton, ( - QuantFP8( - False, - self.act_quant_group_shape, - column_major_scales=False, - use_ue8m0=False, - ) - ) - - def input_to_float8( x: torch.Tensor, dtype: torch.dtype | None = None ) -> tuple[torch.Tensor, torch.Tensor]: @@ -1612,34 +1097,6 @@ def process_fp8_weight_block_strategy( return weight, weight_scale -def maybe_post_process_fp8_weight_block(layer: torch.nn.Module): - assert layer.weight_block_size is not None - - from vllm.utils.deep_gemm import ( - is_deep_gemm_e8m0_used, - should_use_deepgemm_for_fp8_linear, - ) - - # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to - # requantize the weight and input to the 
specific scale - # at the same time. - should_use_deepgemm = should_use_deepgemm_for_fp8_linear( - layer.orig_dtype, layer.weight - ) - if should_use_deepgemm: - scale_attr = ( - "weight_scale_inv" if hasattr(layer, "weight_scale_inv") else "weight_scale" - ) - dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block( - wq=layer.weight.data, - ws=getattr(layer, scale_attr).data, - quant_block_shape=tuple(layer.weight_block_size), - use_e8m0=is_deep_gemm_e8m0_used(), - ) - replace_parameter(layer, "weight", dg_weight) - replace_parameter(layer, scale_attr, dg_weight_scale) - - def process_fp8_weight_tensor_strategy_moe( weight: torch.Tensor, weight_scales: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 1170a2d3a77c..d1b1b77988c7 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -171,6 +171,16 @@ def __str__(self): kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True) +def create_fp8_quant_key( + static: bool, + group_shape: GroupShape, + symmetric: bool = True, + scale_dtype: torch.dtype = torch.float32, +) -> QuantKey: + scale_desc = ScaleDesc(scale_dtype, static, group_shape) + return QuantKey(FP8_DTYPE, scale_desc, symmetric=symmetric) + + # Normalize the group_shape to the full extent for any dims that are -1 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # -1 means full extent diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index abf25db16c78..b6a6805a9d2a 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -413,7 +413,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): def should_use_deepgemm_for_fp8_linear( output_dtype: torch.dtype, - weight: torch.Tensor, + weight_shape: tuple[int, int], supports_deep_gemm: bool | None = None, ): if supports_deep_gemm is 
None: @@ -428,8 +428,8 @@ def should_use_deepgemm_for_fp8_linear( return ( supports_deep_gemm and output_dtype == torch.bfloat16 - and weight.shape[0] % N_MULTIPLE == 0 - and weight.shape[1] % K_MULTIPLE == 0 + and weight_shape[0] % N_MULTIPLE == 0 + and weight_shape[1] % K_MULTIPLE == 0 ) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 373134e655df..ed171db96e73 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -748,8 +748,9 @@ def is_flashinfer_fp8_blockscale_gemm_supported() -> bool: def should_use_flashinfer_for_blockscale_fp8_gemm( is_flashinfer_supported: bool, output_dtype: torch.dtype, - input: torch.Tensor, - weight: torch.Tensor, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + weight_shape: tuple[int, int], ): if not is_flashinfer_supported: return False @@ -760,15 +761,12 @@ def should_use_flashinfer_for_blockscale_fp8_gemm( N_MULTIPLE = 64 K_MULTIPLE = 128 - weight_dtype = weight.dtype - input_dtype = input.dtype - should_use_flashinfer = ( output_dtype == torch.bfloat16 and input_dtype == torch.bfloat16 and weight_dtype == torch.float8_e4m3fn - and weight.shape[0] % N_MULTIPLE == 0 - and weight.shape[1] % K_MULTIPLE == 0 + and weight_shape[0] % N_MULTIPLE == 0 + and weight_shape[1] % K_MULTIPLE == 0 ) return should_use_flashinfer