From 21d7d67c4065de986d6cec9d97ef19f1333c5722 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sat, 6 Sep 2025 14:35:13 -0700 Subject: [PATCH 001/137] Functionalized patterns in prep for utility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/fusion.py | 62 +++++++++++++++----------------------- 1 file changed, 24 insertions(+), 38 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index df54e94a03db..71a3153bf0bc 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -112,13 +112,13 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing - def pattern( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + result_rms = torch.empty_like(input) + # TODO: why does empty_like produce a permute but + # empty via shape doesn't? + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_dtype + ) at1 = auto_functionalized( RMS_OP, result=result_rms, @@ -133,13 +133,8 @@ def pattern( # result return at2[1] - def replacement( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( self.FUSED_OP, result=result, @@ -153,8 +148,6 @@ def replacement( return at[1] inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result - empty_bf16(5, 4), # result_rms empty_bf16(5, 4), # input empty_bf16(1, 5), # weight empty_fp32(1, 1), # scale @@ -175,12 +168,14 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def register(self, pm_pass: PatternMatcherPass): def pattern( - result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_dtype + ) at = auto_functionalized( RMS_ADD_OP, input=input, @@ -196,12 +191,12 @@ def pattern( return at1[1], at[2] def replacement( - result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( self.FUSED_OP, result=result, @@ -216,7 +211,6 @@ def replacement( return at[1], at[2] inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(1, 5), # weight @@ -248,13 +242,11 @@ def __init__( super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + result_rms = torch.empty_like(input) + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_dtype + ) at1 = auto_functionalized( RMS_OP, result=result_rms, @@ -269,13 +261,8 @@ def pattern( # result, scale return at2[1], at2[2] - def replacement( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( self.FUSED_OP, result=result, @@ -291,8 +278,6 @@ def replacement( return at[1], at[2] inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result - empty_bf16(5, 4), # result_rms empty_bf16(5, 4), # input empty_bf16(1, 5), # weight empty_fp32(1, 1), # scale @@ -324,12 +309,14 @@ def __init__( def register(self, pm_pass: PatternMatcherPass): def pattern( - result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_dtype + ) at = auto_functionalized( RMS_ADD_OP, input=input, @@ -345,12 +332,12 @@ def pattern( return at1[1], at[2], at1[2] def replacement( - result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( self.FUSED_OP, result=result, @@ -366,7 +353,6 @@ def replacement( return at[1], at[3], at[2] inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(1, 5), # weight From f3b4cf190736f949eab00d1ee3a1846a409770c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 9 Sep 2025 09:48:53 -0700 Subject: [PATCH 002/137] TEMP Mostly working MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 37 +++++++- vllm/_custom_ops.py | 2 +- vllm/compilation/fusion.py | 99 ++++---------------- vllm/compilation/matcher_utils.py | 116 ++++++++++++++++++++++++ vllm/model_executor/layers/layernorm.py | 54 +++++++---- 5 files changed, 204 insertions(+), 104 deletions(-) create mode 100644 vllm/compilation/matcher_utils.py diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 7c2233643229..fb17dfd0dd46 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -8,6 +8,7 @@ from vllm.compilation.fusion import ( FUSED_OPS, QUANT_OPS, + RMS_OP, FusedRMSQuantKey, RMSNormQuantFusionPass, ) @@ -65,6 +66,9 @@ def __init__( act_quant_group_shape=group_shape, ) + self.enable_rms_norm = self.norm[0].enabled() + self.enable_quant_fp8 = self.fp8_linear.quant_fp8.enabled() + def forward(self, x): resid = torch.sqrt(x) y = self.norm[0](x) @@ -82,7 +86,18 @@ def forward(self, x): return y3 def ops_in_model_before(self): - return [QUANT_OPS[self.key]] + ops = [] + if self.enable_rms_norm: + ops += [RMS_OP] + else: + ops += [torch.ops.aten.rsqrt.default] + + if self.enable_quant_fp8: + ops += [QUANT_OPS[self.key]] + else: + ops += [torch.ops.aten.reciprocal.default] + + return ops def ops_in_model_after(self): return [ @@ -91,11 +106,13 @@ def ops_in_model_after(self): ] -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) # , torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) +@pytest.mark.parametrize("enable_rms_norm", [True]) # , False]) +@pytest.mark.parametrize("enable_quant_fp8", [True]) # , False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. @pytest.mark.parametrize( @@ -105,17 +122,29 @@ def ops_in_model_after(self): not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" ) def test_fusion_rmsnorm_quant( - dtype, hidden_size, num_tokens, eps, static, cuda_force_torch + dtype, + hidden_size, + num_tokens, + eps, + static, + enable_rms_norm, + enable_quant_fp8, + cuda_force_torch, ): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) maybe_create_device_identity() # needed for certain non-cutlass fp8 paths + custom_ops = [] + if enable_rms_norm: + custom_ops.append("+rms_norm") + if enable_quant_fp8: + custom_ops.append("+quant_fp8") vllm_config = VllmConfig( compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm", "+quant_fp8"], + custom_ops=custom_ops, pass_config=PassConfig(enable_fusion=True, enable_noop=True), ) ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index eac0a5009e81..646d8de39a45 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1507,7 +1507,7 @@ def scaled_fp8_quant( output, input, scale, scale_ub ) else: - scale = torch.empty(1, device=input.device, dtype=torch.float32) + scale = torch.empty((1, 1), device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: assert scale.numel() == 1, f"{scale.shape}" diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 71a3153bf0bc..4afb8ba537e7 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -24,6 +24,7 @@ from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherQuant, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -99,6 +100,9 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey): assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + self.quant_matcher = MatcherQuant(key.quant) + class RMSNormStaticQuantPattern(RMSNormQuantPattern): def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): @@ -113,25 +117,8 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - result_rms = torch.empty_like(input) - # TODO: why does empty_like produce a permute but - # empty via shape doesn't? - result = torch.empty( - input.shape, device=input.device, dtype=self.quant_dtype - ) - at1 = auto_functionalized( - RMS_OP, - result=result_rms, - input=input, - weight=weight, - epsilon=self.epsilon, - ) - at2 = auto_functionalized( - self.QUANT_OP, result=result, input=at1[1], scale=scale - ) - - # result - return at2[1] + result_rms = self.rmsnorm_matcher(input, weight) + return self.quant_matcher(result_rms, scale) def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) @@ -173,22 +160,10 @@ def pattern( weight: torch.Tensor, scale: torch.Tensor, ): - result = torch.empty( - input.shape, device=input.device, dtype=self.quant_dtype - ) - at = auto_functionalized( - RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - at1 = auto_functionalized( - self.QUANT_OP, result=result, input=at[1], scale=scale - ) + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + result = self.quant_matcher(result_rms, scale) - # result, residual - return at1[1], at[2] + return result, residual def replacement( input: torch.Tensor, @@ -242,27 +217,14 @@ def __init__( super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - result_rms = torch.empty_like(input) - result = torch.empty( - input.shape, device=input.device, dtype=self.quant_dtype - ) - at1 = auto_functionalized( - RMS_OP, - result=result_rms, - input=input, - weight=weight, - epsilon=self.epsilon, - ) - at2 = auto_functionalized( - self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None - ) - + def pattern(input: torch.Tensor, weight: torch.Tensor): + result_rms = self.rmsnorm_matcher(input, weight) # result, scale - return at2[1], at2[2] + return self.quant_matcher(result_rms) - def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + def replacement(input: torch.Tensor, weight: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale(input) at = auto_functionalized( self.FUSED_OP, result=result, @@ -280,7 +242,6 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): inputs = [ empty_bf16(5, 4), # input empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale ] pm.register_replacement( @@ -308,36 +269,17 @@ def __init__( super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): - result = torch.empty( - input.shape, device=input.device, dtype=self.quant_dtype - ) - at = auto_functionalized( - RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - at1 = auto_functionalized( - self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None - ) + def pattern(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + result, scale = self.quant_matcher(result_rms) - # result, residual, scale - return at1[1], at[2], at1[2] + return result, residual, scale def replacement( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor ): result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale(input) at = auto_functionalized( self.FUSED_OP, result=result, @@ -356,7 +298,6 @@ def replacement( empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale ] pm.register_replacement( diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py new file mode 100644 index 000000000000..1200e236bae4 --- /dev/null +++ b/vllm/compilation/matcher_utils.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Union + +import torch +from torch._higher_order_ops import auto_functionalized +from torch._ops import OpOverload + +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + _normalize_quant_group_shape, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, +) + +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + +QUANT_OPS: dict[QuantKey, OpOverload] = { + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 +} + +# TODO +# if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): +# QUANT_OPS[ +# kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 + + +class MatcherRMSNorm: + def __init__(self, epsilon: float): + self.epsilon = epsilon + + def forward( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if residual is None: + result = torch.empty_like(input) + _, result = auto_functionalized( + RMS_OP, + result=result, + input=input, + weight=weight, + epsilon=self.epsilon, + ) + + return result + else: + _, result, residual = auto_functionalized( + RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + + return result, residual + + def __call__( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + return self.forward(input, weight, residual) + + +class MatcherQuant: + def __init__(self, quant_key: QuantKey): + self.quant_key = quant_key + assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" + self.QUANT_OP = QUANT_OPS[quant_key] + + def forward( + self, input: torch.Tensor, scale: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + # TODO: why does empty_like produce a permute but + # empty via shape doesn't? + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_key.dtype + ) + + if self.quant_key.scale.static: + assert scale is not None + _, result = auto_functionalized( + self.QUANT_OP, result=result, input=input, scale=scale + ) + return result + else: + assert scale is None + scale = self.make_scale(input) + _, result, scale = auto_functionalized( + self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None + ) + return result, scale + + def make_scale(self, input: torch.Tensor): + normalized_group_shape = _normalize_quant_group_shape( + input, self.quant_key.scale.group_shape + ) + scale_shape = ( + input.shape[0] // normalized_group_shape[0], + input.shape[1] // normalized_group_shape[1], + ) + + return torch.empty(scale_shape, device=input.device, dtype=torch.float32) + + def __call__( + self, input: torch.Tensor, scale: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + return self.forward(input, scale) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 6a49ae42ca89..3c58832cad4c 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -170,13 +170,10 @@ def __init__( self.variance_size_override = ( None if var_hidden_size == hidden_size else var_hidden_size ) - self.has_weight = has_weight - if dtype is not None: - self.weight = torch.ones(hidden_size, dtype=dtype) - else: - self.weight = torch.ones(hidden_size) - if self.has_weight: - self.weight = nn.Parameter(self.weight) + self.weight = None + if has_weight: + dtype = dtype or torch.get_default_dtype() + self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype)) weight_dtype = self.weight.data.dtype if current_platform.is_rocm(): @@ -187,9 +184,13 @@ def __init__( with_fused_add=True, dtype=weight_dtype ) - def forward_native( - self, + @staticmethod + def forward_static( x: torch.Tensor, + variance_epsilon: float, + hidden_size: int, + variance_size_override: Optional[int], + weight: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" @@ -199,35 +200,48 @@ def forward_native( x = x + residual.to(torch.float32) residual = x.to(orig_dtype) - hidden_size = x.shape[-1] - if hidden_size != self.hidden_size: + if x.shape[-1] != hidden_size: raise ValueError( - "Expected hidden_size to be " - f"{self.hidden_size}, but found: {hidden_size}" + f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}" ) - if self.variance_size_override is None: + if variance_size_override is None: x_var = x else: - if hidden_size < self.variance_size_override: + if hidden_size < variance_size_override: raise ValueError( "Expected hidden_size to be at least " - f"{self.variance_size_override}, but found: {hidden_size}" + f"{variance_size_override}, but found: {hidden_size}" ) - x_var = x[:, :, : self.variance_size_override] + x_var = x[:, :, :variance_size_override] variance = x_var.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x * torch.rsqrt(variance + variance_epsilon) x = x.to(orig_dtype) - if self.has_weight: - x = x * self.weight + if weight is not None: + x = x * weight if residual is None: return x else: return x, residual + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + return self.forward_static( + x, + self.variance_epsilon, + self.hidden_size, + self.variance_size_override, + self.weight.data, + residual, + ) + def forward_cuda( self, x: torch.Tensor, From cdad3c05ea12ddba69805277085e2cf658a065f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 12 Sep 2025 12:11:48 -0700 Subject: [PATCH 003/137] TEMP: fixed rmsnorm issue (TODO assert dtypes in fused norm_quant kernels) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- csrc/layernorm_kernels.cu | 1 + tests/compile/backend.py | 19 +- tests/compile/test_fusion.py | 100 ++++----- vllm/compilation/fusion.py | 277 +++++++++++------------- vllm/compilation/matcher_utils.py | 105 ++++++--- vllm/model_executor/layers/layernorm.py | 4 +- 6 files changed, 261 insertions(+), 245 deletions(-) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 6c3685f6f7cd..b738cdbbdc53 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -380,6 +380,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { + TORCH_CHECK(input.scalar_type() == residual.scalar_type()); TORCH_CHECK(residual.is_contiguous()); TORCH_CHECK(weight.is_contiguous()); int hidden_size = input.size(-1); diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 36bc832a1329..113906af0203 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -4,10 +4,13 @@ import weakref from collections.abc import Sequence from copy import deepcopy +from pathlib import Path from typing import Callable, Union +import depyf from torch import fx from torch._ops import OpOverload +from torch.fx._utils import lazy_format_graph_code from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.inductor_pass import InductorPass @@ -46,11 +49,20 @@ class TestBackend: def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.custom_passes = list(passes) - compile_config = get_current_vllm_config().compilation_config + vllm_config = get_current_vllm_config() + compile_config = vllm_config.compilation_config self.inductor_config = compile_config.inductor_compile_config self.inductor_config["force_disable_caches"] = True self.inductor_config["post_grad_custom_post_pass"] = self.post_pass + if compile_config.debug_dump_path: + self.debug_dump_path = (Path(compile_config.debug_dump_path) / + f"rank_{vllm_config.parallel_config.rank}") + self.ctx = depyf.prepare_debug(str(self.debug_dump_path)) + self.ctx.__enter__() + else: + self.ctx = None + def __call__(self, graph: fx.GraphModule, example_inputs): self.graph_pre_compile = deepcopy(graph) from torch._inductor.compile_fx import compile_fx @@ -60,6 +72,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs): @with_pattern_match_debug def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) + lazy_format_graph_code("graph_pre_pass", graph.owning_module) VllmInductorPass.dump_prefix = 0 for pass_ in self.custom_passes: @@ -69,9 +82,13 @@ def post_pass(self, graph: fx.Graph): VllmInductorPass.dump_prefix = None self.graph_post_pass = deepcopy(graph) + lazy_format_graph_code("graph_post_pass", graph.owning_module) # assign by reference, will reflect the final state of the graph self.final_graph = graph + if self.ctx is not None: + self.ctx.__exit__(None, None, None) + def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True): for op in ops: num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index fb17dfd0dd46..ac5d9b9c93bf 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,27 +5,17 @@ import torch import vllm.plugins -from vllm.compilation.fusion import ( - FUSED_OPS, - QUANT_OPS, - RMS_OP, - FusedRMSQuantKey, - RMSNormQuantFusionPass, -) +from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, RMS_OP, + FusedRMSQuantKey, RMSNormQuantFusionPass) from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig +from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, + VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, - QuantKey, - ScaleDesc, -) + GroupShape, QuantKey, ScaleDesc) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, - cutlass_fp8_supported, - maybe_create_device_identity, -) + Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity) from vllm.platforms import current_platform from ..utils import override_cutlass_fp8_supported @@ -35,15 +25,9 @@ class TestModel(torch.nn.Module): - def __init__( - self, - hidden_size: int, - eps: float, - static: bool, - cuda_force_torch: bool, - *args, - **kwargs, - ): + + def __init__(self, hidden_size: int, eps: float, static: bool, + cuda_force_torch: bool, *args, **kwargs): super().__init__(*args, **kwargs) self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] @@ -70,18 +54,21 @@ def __init__( self.enable_quant_fp8 = self.fp8_linear.quant_fp8.enabled() def forward(self, x): - resid = torch.sqrt(x) + # avoid having graph input be an arg to a pattern directly + x = resid = torch.relu(x) y = self.norm[0](x) - x2 = self.fp8_linear.apply( - y, self.w[0], self.wscale[0], input_scale=self.scale[0] - ) + x2 = self.fp8_linear.apply(y, + self.w[0], + self.wscale[0], + input_scale=self.scale[0]) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear.apply( - y2, self.w[1], self.wscale[1], input_scale=self.scale[1] - ) + x3 = self.fp8_linear.apply(y2, + self.w[1], + self.wscale[1], + input_scale=self.scale[1]) y3, resid = self.norm[2](x3, resid) # use resid here return y3 @@ -102,35 +89,26 @@ def ops_in_model_before(self): def ops_in_model_after(self): return [ FUSED_OPS[FusedRMSQuantKey(self.key, False)], - FUSED_OPS[FusedRMSQuantKey(self.key, True)], + FUSED_OPS[FusedRMSQuantKey(self.key, True)] ] -@pytest.mark.parametrize("dtype", [torch.float16]) # , torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) #, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) -@pytest.mark.parametrize("enable_rms_norm", [True]) # , False]) -@pytest.mark.parametrize("enable_quant_fp8", [True]) # , False]) +@pytest.mark.parametrize("enable_rms_norm", [True, False]) +@pytest.mark.parametrize("enable_quant_fp8", [True, False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. -@pytest.mark.parametrize( - "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] -) -@pytest.mark.skipif( - not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" -) -def test_fusion_rmsnorm_quant( - dtype, - hidden_size, - num_tokens, - eps, - static, - enable_rms_norm, - enable_quant_fp8, - cuda_force_torch, -): +@pytest.mark.parametrize("cuda_force_torch", + [True, False] if cutlass_fp8_supported() else [True]) +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Only test on CUDA and ROCm") +def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, + enable_rms_norm, enable_quant_fp8, + cuda_force_torch): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) @@ -141,13 +119,13 @@ def test_fusion_rmsnorm_quant( custom_ops.append("+rms_norm") if enable_quant_fp8: custom_ops.append("+quant_fp8") - vllm_config = VllmConfig( - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - custom_ops=custom_ops, - pass_config=PassConfig(enable_fusion=True, enable_noop=True), - ) - ) + vllm_config = VllmConfig(compilation_config=CompilationConfig( + debug_dump_path=f"/home/luka/git/vllm/._workspace/" + f"debug_dump_{enable_rms_norm}_{enable_quant_fp8}", + level=CompilationLevel.PIECEWISE, + custom_ops=custom_ops, + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + )) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) @@ -179,7 +157,7 @@ def test_fusion_rmsnorm_quant( assert fusion_pass.matched_count == 2 # In pre-nodes, fp8 quant should be there and fused kernels should not - backend.check_before_ops(model.ops_in_model_before()) + # backend.check_before_ops(model.ops_in_model_before()) # In post-nodes, fused kernels should be there and fp8 quant should not - backend.check_after_ops(model.ops_in_model_after()) + # backend.check_after_ops(model.ops_in_model_after()) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 4afb8ba537e7..8e3a1de99898 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -12,15 +12,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, - QuantKey, - ScaleDesc, - kFp8DynamicTensorSym, - kFp8DynamicTokenSym, - kFp8StaticTensorSym, - kNvfp4Quant, - kStaticTensorScale, -) + GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, + kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode @@ -48,9 +41,12 @@ def empty_i32(*args, **kwargs): RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 + kFp8StaticTensorSym: + torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: + torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: + torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default @@ -62,42 +58,38 @@ class FusedRMSQuantKey(NamedTuple): quant: type of quantization fused_add: does the op also perform the residual add """ - quant: QuantKey fused_add: bool def __str__(self): - return ( - f"FusedQuantKey({self.quant}, with" - f"{'' if self.fused_add else 'out'} residual)" - ) + return (f"FusedQuantKey({self.quant}, with" + f"{'' if self.fused_add else 'out'} residual)") FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { - FusedRMSQuantKey( - kFp8StaticTensorSym, False - ): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8StaticTensorSym, True - ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8DynamicTokenSym, False - ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8DynamicTokenSym, True - ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey(kFp8StaticTensorSym, False): + torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey(kFp8StaticTensorSym, True): + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey(kFp8DynamicTokenSym, False): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey(kFp8DynamicTokenSym, True): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 } class RMSNormQuantPattern: + def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon self.quant_dtype = key.quant.dtype - assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}" + assert key.quant in QUANT_OPS, \ + f"unsupported quantization scheme {key.quant}" self.QUANT_OP = QUANT_OPS[key.quant] - assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" + assert key in FUSED_OPS, \ + f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] self.rmsnorm_matcher = MatcherRMSNorm(epsilon) @@ -105,82 +97,80 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey): class RMSNormStaticQuantPattern(RMSNormQuantPattern): - def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): - fused_key = FusedRMSQuantKey( - fused_add=False, - quant=QuantKey( - dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric - ), - ) + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + fused_key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing - def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + def pattern(input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): result_rms = self.rmsnorm_matcher(input, weight) return self.quant_matcher(result_rms, scale) - def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + def replacement(input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) - at = auto_functionalized( - self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - ) + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon) # result return at[1] inputs = [ empty_bf16(5, 4), # input - empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale + empty_bf16(4,), # weight + empty_fp32(1, 1) # scale ] + pattern(*inputs) - pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, + pm_pass) class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): - def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): - key = FusedRMSQuantKey( - fused_add=True, - quant=QuantKey( - dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric - ), - ) + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): - result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + + def pattern(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher( + input, weight, residual) result = self.quant_matcher(result_rms, scale) return result, residual - def replacement( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def replacement(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) - at = auto_functionalized( - self.FUSED_OP, - result=result, - input=input, - residual=residual, - weight=weight, - scale=scale, - epsilon=self.epsilon, - ) + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon) # result, residual return at[1], at[2] @@ -188,8 +178,8 @@ def replacement( inputs = [ empty_bf16(5, 4), # input empty_bf16(5, 4), # residual - empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale + empty_bf16(4, ), # weight + empty_fp32(1, 1) # scale ] pm.register_replacement( @@ -202,21 +192,21 @@ def replacement( class RMSNormDynamicQuantPattern(RMSNormQuantPattern): - def __init__( - self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True, - ): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True): scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey( - fused_add=False, - quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), - ) + key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + scale=scale, + symmetric=symmetric)) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): + def pattern(input: torch.Tensor, weight: torch.Tensor): result_rms = self.rmsnorm_matcher(input, weight) # result, scale @@ -225,23 +215,21 @@ def pattern(input: torch.Tensor, weight: torch.Tensor): def replacement(input: torch.Tensor, weight: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) - at = auto_functionalized( - self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=None, - ) + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None) # result, scale return at[1], at[2] inputs = [ empty_bf16(5, 4), # input - empty_bf16(1, 5), # weight + empty_bf16(4), # weight ] pm.register_replacement( @@ -254,42 +242,41 @@ def replacement(input: torch.Tensor, weight: torch.Tensor): class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): - def __init__( - self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True, - ): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True): scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey( - fused_add=True, - quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), - ) + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + scale=scale, + symmetric=symmetric)) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor): - result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + + def pattern(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher( + input, weight, residual) result, scale = self.quant_matcher(result_rms) return result, residual, scale - def replacement( - input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor - ): + def replacement(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) - at = auto_functionalized( - self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=residual, - ) + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual) # result, residual, scale return at[1], at[3], at[2] @@ -297,7 +284,7 @@ def replacement( inputs = [ empty_bf16(5, 4), # input empty_bf16(5, 4), # residual - empty_bf16(1, 5), # weight + empty_bf16(4), # weight ] pm.register_replacement( @@ -320,25 +307,24 @@ def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="rmsnorm_quant_fusion_pass" - ) + pass_name="rmsnorm_quant_fusion_pass") for epsilon in [1e-5, 1e-6]: # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + RMSNormStaticQuantPattern(epsilon, + FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns - ) + self.patterns) # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + RMSNormDynamicQuantPattern(epsilon, + FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns - ) + self.patterns) self.dump_patterns(config, self.patterns) @@ -348,11 +334,8 @@ def __call__(self, graph: fx.Graph): logger.debug("Replaced %s patterns", self.matched_count) def uuid(self) -> Any: - return self.hash_source( - self, - RMSNormQuantPattern, - RMSNormStaticQuantPattern, - RMSNormDynamicQuantPattern, - FusedAddRMSNormStaticQuantPattern, - FusedAddRMSNormDynamicQuantPattern, - ) + return self.hash_source(self, RMSNormQuantPattern, + RMSNormStaticQuantPattern, + RMSNormDynamicQuantPattern, + FusedAddRMSNormStaticQuantPattern, + FusedAddRMSNormDynamicQuantPattern) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 1200e236bae4..1b88d2916b0d 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -6,21 +6,21 @@ from torch._higher_order_ops import auto_functionalized from torch._ops import OpOverload +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, - _normalize_quant_group_shape, - kFp8DynamicTensorSym, - kFp8DynamicTokenSym, - kFp8StaticTensorSym, -) + QuantKey, _normalize_quant_group_shape, kFp8DynamicTensorSym, + kFp8DynamicTokenSym, kFp8StaticTensorSym) RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 + kFp8StaticTensorSym: + torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: + torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: + torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } # TODO @@ -29,11 +29,18 @@ # kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 -class MatcherRMSNorm: - def __init__(self, epsilon: float): +class MatcherRMSNorm: # TODO separate residual and not residual + + def __init__(self, epsilon: float, enabled: Optional[bool] = None): self.epsilon = epsilon - def forward( + if enabled is None: + # TODO either pass config to enabled or set it globally (global during pass init seems reasonable) + enabled = RMSNorm.enabled() + + self.forward = self.forward_custom if enabled else self.forward_native + + def forward_custom( self, input: torch.Tensor, weight: torch.Tensor, @@ -51,16 +58,36 @@ def forward( return result else: - _, result, residual = auto_functionalized( - RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) + _, result, residual = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) return result, residual + def forward_native( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = input.dtype + x = input.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x + + variance = x.pow(2).mean(dim=-1, keepdim=True) + + x = x * torch.rsqrt(variance + self.epsilon) + x = x.to(orig_dtype) + if weight is not None: + x = x * weight + + return x if residual is None else (x, residual) + + def __call__( self, input: torch.Tensor, @@ -71,46 +98,56 @@ def __call__( class MatcherQuant: + def __init__(self, quant_key: QuantKey): self.quant_key = quant_key - assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" + assert quant_key in QUANT_OPS, \ + f"unsupported quantization scheme {quant_key}" self.QUANT_OP = QUANT_OPS[quant_key] def forward( - self, input: torch.Tensor, scale: Optional[torch.Tensor] = None + self, + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: # TODO: why does empty_like produce a permute but # empty via shape doesn't? - result = torch.empty( - input.shape, device=input.device, dtype=self.quant_key.dtype - ) + result = torch.empty(input.shape, + device=input.device, + dtype=self.quant_key.dtype) if self.quant_key.scale.static: assert scale is not None - _, result = auto_functionalized( - self.QUANT_OP, result=result, input=input, scale=scale - ) + _, result = auto_functionalized(self.QUANT_OP, + result=result, + input=input, + scale=scale) return result else: assert scale is None scale = self.make_scale(input) - _, result, scale = auto_functionalized( - self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None - ) + _, result, scale = auto_functionalized(self.QUANT_OP, + result=result, + input=input, + scale=scale, + scale_ub=None) return result, scale def make_scale(self, input: torch.Tensor): normalized_group_shape = _normalize_quant_group_shape( - input, self.quant_key.scale.group_shape - ) + input, self.quant_key.scale.group_shape) scale_shape = ( input.shape[0] // normalized_group_shape[0], input.shape[1] // normalized_group_shape[1], ) - return torch.empty(scale_shape, device=input.device, dtype=torch.float32) + return torch.empty(scale_shape, + device=input.device, + dtype=torch.float32) def __call__( - self, input: torch.Tensor, scale: Optional[torch.Tensor] = None + self, + input: torch.Tensor, + scale: Optional[torch.Tensor] = None ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: return self.forward(input, scale) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 3c58832cad4c..976b2e852265 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -189,9 +189,9 @@ def forward_static( x: torch.Tensor, variance_epsilon: float, hidden_size: int, - variance_size_override: Optional[int], weight: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None, + variance_size_override: Optional[int] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" orig_dtype = x.dtype @@ -237,9 +237,9 @@ def forward_native( x, self.variance_epsilon, self.hidden_size, - self.variance_size_override, self.weight.data, residual, + self.variance_size_override, ) def forward_cuda( From 8e4a56f57581e98bfa0e146197cfff860a2f95f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 16 Sep 2025 10:47:13 -0700 Subject: [PATCH 004/137] rms works fully now, had to remove more conversions (and add them in replacements). TODO pass to remove unnecessary conversions? MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- ...fused_layernorm_dynamic_per_token_quant.cu | 4 ++ tests/compile/test_fusion.py | 21 ++++---- vllm/compilation/fusion.py | 53 +++++++++++++------ vllm/compilation/matcher_utils.py | 15 +++--- 4 files changed, 61 insertions(+), 32 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 95aa92e25b30..92d6c2f402a2 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant( if (scale_ub.has_value()) { TORCH_CHECK(out.dtype() == kFp8Type); } + TORCH_CHECK(weight.dtype() == input.dtype()); TORCH_CHECK(scales.dtype() == torch::kFloat32); + if (residual) { + TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + } VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] { diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index ac5d9b9c93bf..aea9038a64e3 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -9,8 +9,8 @@ FusedRMSQuantKey, RMSNormQuantFusionPass) from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, - VllmConfig) +from vllm.config import (CompilationConfig, CompilationLevel, ModelConfig, + PassConfig, VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, ScaleDesc) @@ -119,13 +119,16 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, custom_ops.append("+rms_norm") if enable_quant_fp8: custom_ops.append("+quant_fp8") - vllm_config = VllmConfig(compilation_config=CompilationConfig( - debug_dump_path=f"/home/luka/git/vllm/._workspace/" - f"debug_dump_{enable_rms_norm}_{enable_quant_fp8}", - level=CompilationLevel.PIECEWISE, - custom_ops=custom_ops, - pass_config=PassConfig(enable_fusion=True, enable_noop=True), - )) + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + debug_dump_path=f"/home/luka/git/vllm/._workspace/" + f"debug_dump_{enable_rms_norm}_{enable_quant_fp8}", + level=CompilationLevel.PIECEWISE, + custom_ops=custom_ops, + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + ), + ) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 8e3a1de99898..0efdd7d2d0e4 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -9,7 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, @@ -117,6 +117,10 @@ def pattern(input: torch.Tensor, weight: torch.Tensor, def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=torch.float16) # TODO model dtype + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized(self.FUSED_OP, result=result, @@ -130,7 +134,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, inputs = [ empty_bf16(5, 4), # input - empty_bf16(4,), # weight + empty_bf16(4, ), # weight empty_fp32(1, 1) # scale ] pattern(*inputs) @@ -163,6 +167,11 @@ def pattern(input: torch.Tensor, residual: torch.Tensor, def replacement(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=torch.float16) # TODO model dtype + residual = residual.to(dtype=torch.float16) + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized(self.FUSED_OP, result=result, @@ -176,9 +185,11 @@ def replacement(input: torch.Tensor, residual: torch.Tensor, return at[1], at[2] inputs = [ + # TODO: maybe 32bit for torch impl? + # TODO dtype doesn't seem to matter? empty_bf16(5, 4), # input empty_bf16(5, 4), # residual - empty_bf16(4, ), # weight + empty_bf16(4, ), # weight empty_fp32(1, 1) # scale ] @@ -213,6 +224,10 @@ def pattern(input: torch.Tensor, weight: torch.Tensor): return self.quant_matcher(result_rms) def replacement(input: torch.Tensor, weight: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=torch.float16) # TODO model dtype + result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) at = auto_functionalized(self.FUSED_OP, @@ -267,6 +282,11 @@ def pattern(input: torch.Tensor, residual: torch.Tensor, def replacement(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=torch.float16) # TODO model dtype + residual = residual.to(dtype=torch.float16) + result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) at = auto_functionalized(self.FUSED_OP, @@ -309,22 +329,23 @@ def __init__(self, config: VllmConfig): self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="rmsnorm_quant_fusion_pass") - for epsilon in [1e-5, 1e-6]: - # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, - FP8_DTYPE).register(self.patterns) + with set_current_vllm_config(config, check_compile=False): + for epsilon in [1e-5, 1e-6]: + # Fuse rms_norm + static fp8 quant + RMSNormStaticQuantPattern(epsilon, + FP8_DTYPE).register(self.patterns) - # Fuse fused_add_rms_norm + static fp8 quant - FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns) + # Fuse fused_add_rms_norm + static fp8 quant + FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( + self.patterns) - # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, - FP8_DTYPE).register(self.patterns) + # Fuse rms_norm + dynamic per-token fp8 quant + RMSNormDynamicQuantPattern(epsilon, + FP8_DTYPE).register(self.patterns) - # Fuse fused_add_rms_norm + dynamic per-token fp8 quant - FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns) + # Fuse fused_add_rms_norm + dynamic per-token fp8 quant + FusedAddRMSNormDynamicQuantPattern( + epsilon, FP8_DTYPE).register(self.patterns) self.dump_patterns(config, self.patterns) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 1b88d2916b0d..ebb5e26b324c 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -6,6 +6,7 @@ from torch._higher_order_ops import auto_functionalized from torch._ops import OpOverload +from vllm.config import get_current_vllm_config from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, _normalize_quant_group_shape, kFp8DynamicTensorSym, @@ -29,16 +30,18 @@ # kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 -class MatcherRMSNorm: # TODO separate residual and not residual +class MatcherRMSNorm: # TODO separate residual and not residual def __init__(self, epsilon: float, enabled: Optional[bool] = None): self.epsilon = epsilon if enabled is None: - # TODO either pass config to enabled or set it globally (global during pass init seems reasonable) + # TODO either pass config to enabled or set it globally + # (global during pass init seems reasonable) enabled = RMSNorm.enabled() self.forward = self.forward_custom if enabled else self.forward_native + self.model_dtype = get_current_vllm_config().model_config.dtype def forward_custom( self, @@ -72,22 +75,20 @@ def forward_native( weight: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - orig_dtype = input.dtype - x = input.to(torch.float32) + x = input # .to(torch.float32) if residual is not None: x = x + residual.to(torch.float32) - residual = x + residual = x # conversion to 16-bit is eliminated in full graph variance = x.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(variance + self.epsilon) - x = x.to(orig_dtype) + x = x.to(self.model_dtype) if weight is not None: x = x * weight return x if residual is None else (x, residual) - def __call__( self, input: torch.Tensor, From e151e6d16e1ef6c2c0cddf6ee9fc074f88dc3bd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 16 Sep 2025 11:08:39 -0700 Subject: [PATCH 005/137] quant works except (torch,torch) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/fusion.py | 4 ++-- vllm/compilation/matcher_utils.py | 37 +++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 0efdd7d2d0e4..fffe2a6432ec 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -113,7 +113,7 @@ def register(self, pm_pass: PatternMatcherPass): def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): result_rms = self.rmsnorm_matcher(input, weight) - return self.quant_matcher(result_rms, scale) + return self.quant_matcher(result_rms, scale)[0] def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): @@ -161,7 +161,7 @@ def pattern(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): result_rms, residual = self.rmsnorm_matcher( input, weight, residual) - result = self.quant_matcher(result_rms, scale) + result, _ = self.quant_matcher(result_rms, scale) return result, residual diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index ebb5e26b324c..51fff7fe0c9e 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -8,6 +8,7 @@ from vllm.config import get_current_vllm_config from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, _normalize_quant_group_shape, kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym) @@ -100,17 +101,29 @@ def __call__( class MatcherQuant: - def __init__(self, quant_key: QuantKey): + def __init__(self, quant_key: QuantKey, enabled: Optional[bool] = None): + self.quant_key = quant_key assert quant_key in QUANT_OPS, \ f"unsupported quantization scheme {quant_key}" self.QUANT_OP = QUANT_OPS[quant_key] - def forward( + assert quant_key.scale2 is None + self.quant_fp8 = QuantFP8(quant_key.scale.static, + quant_key.scale.group_shape) + + if enabled is None: + # TODO either pass config to enabled or set it globally + # (global during pass init seems reasonable) + enabled = self.quant_fp8.enabled() + + self.forward = self.forward_custom if enabled else self.forward_native + + def forward_custom( self, input: torch.Tensor, scale: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor]: # TODO: why does empty_like produce a permute but # empty via shape doesn't? result = torch.empty(input.shape, @@ -123,7 +136,7 @@ def forward( result=result, input=input, scale=scale) - return result + return result, scale else: assert scale is None scale = self.make_scale(input) @@ -134,6 +147,13 @@ def forward( scale_ub=None) return result, scale + def forward_native( + self, + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.quant_fp8(input, scale) + def make_scale(self, input: torch.Tensor): normalized_group_shape = _normalize_quant_group_shape( input, self.quant_key.scale.group_shape) @@ -146,9 +166,8 @@ def make_scale(self, input: torch.Tensor): device=input.device, dtype=torch.float32) - def __call__( - self, - input: torch.Tensor, - scale: Optional[torch.Tensor] = None - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + def __call__(self, + input: torch.Tensor, + scale: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: return self.forward(input, scale) From 14fdc8b9d51ac418c8f09f367c1a81228ee163ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 18 Sep 2025 12:32:27 -0700 Subject: [PATCH 006/137] quant with fix for pure torch, broke others MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 6 ++---- vllm/compilation/fusion.py | 8 ++++---- vllm/compilation/matcher_utils.py | 10 +++++++--- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index aea9038a64e3..4a9a497989e8 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -147,10 +147,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, model2 = torch.compile(model, backend=backend) result2 = model2(x) - # Higher tol for dynamic, even higher for bfloat16 - if static: - ATOL, RTOL = (1e-3, 1e-3) - elif dtype == torch.float16: + # Higher tol for dynamic bfloat16 + if dtype == torch.float16 or static: ATOL, RTOL = (2e-3, 2e-3) else: ATOL, RTOL = (1e-2, 1e-2) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index fffe2a6432ec..92caf47945ef 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -26,7 +26,7 @@ def empty_bf16(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") + return torch.empty(*args, **kwargs, dtype=torch.float16, device="cuda") def empty_fp32(*args, **kwargs): @@ -133,7 +133,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, return at[1] inputs = [ - empty_bf16(5, 4), # input + empty_fp32(5, 4), # input # TODO: rms_input empty_bf16(4, ), # weight empty_fp32(1, 1) # scale ] @@ -185,8 +185,8 @@ def replacement(input: torch.Tensor, residual: torch.Tensor, return at[1], at[2] inputs = [ - # TODO: maybe 32bit for torch impl? - # TODO dtype doesn't seem to matter? + # TODO: maybe 32bit for torch impl? yes to resolve bug + # TODO dtype doesn't seem to matter? it does matter for what cvts get traced empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(4, ), # weight diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 51fff7fe0c9e..9cde9230211f 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -43,6 +43,10 @@ def __init__(self, epsilon: float, enabled: Optional[bool] = None): self.forward = self.forward_custom if enabled else self.forward_native self.model_dtype = get_current_vllm_config().model_config.dtype + print(self.model_dtype) + + def inputs(self): + return def forward_custom( self, @@ -76,10 +80,10 @@ def forward_native( weight: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - x = input # .to(torch.float32) + x = input.to(torch.float32) if residual is not None: - x = x + residual.to(torch.float32) - residual = x # conversion to 16-bit is eliminated in full graph + x = x + residual + residual = x.to(self.model_dtype) variance = x.pow(2).mean(dim=-1, keepdim=True) From 05a65f39a5043315afcad16f9060c21079908ead Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 18 Sep 2025 13:21:46 -0700 Subject: [PATCH 007/137] ALL WORKS MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/fusion.py | 47 ++++------- vllm/compilation/matcher_utils.py | 125 ++++++++++++++++++++++-------- 2 files changed, 110 insertions(+), 62 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 92caf47945ef..4e1b569f77e0 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -17,7 +17,7 @@ from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode -from .matcher_utils import MatcherQuant, MatcherRMSNorm +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuant, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -92,7 +92,8 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey): f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] - self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) if not key.fused_add \ + else MatcherFusedAddRMSNorm(epsilon) self.quant_matcher = MatcherQuant(key.quant) @@ -133,8 +134,8 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, return at[1] inputs = [ - empty_fp32(5, 4), # input # TODO: rms_input - empty_bf16(4, ), # weight + # input, weight + *self.rmsnorm_matcher.inputs(), empty_fp32(1, 1) # scale ] pattern(*inputs) @@ -157,16 +158,16 @@ def __init__(self, def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, scale: torch.Tensor): + def pattern(input: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor, scale: torch.Tensor): result_rms, residual = self.rmsnorm_matcher( input, weight, residual) result, _ = self.quant_matcher(result_rms, scale) return result, residual - def replacement(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, scale: torch.Tensor): + def replacement(input: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor, scale: torch.Tensor): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=torch.float16) # TODO model dtype @@ -185,11 +186,8 @@ def replacement(input: torch.Tensor, residual: torch.Tensor, return at[1], at[2] inputs = [ - # TODO: maybe 32bit for torch impl? yes to resolve bug - # TODO dtype doesn't seem to matter? it does matter for what cvts get traced - empty_bf16(5, 4), # input - empty_bf16(5, 4), # residual - empty_bf16(4, ), # weight + # input, weight, residual + *self.rmsnorm_matcher.inputs(), empty_fp32(1, 1) # scale ] @@ -242,15 +240,10 @@ def replacement(input: torch.Tensor, weight: torch.Tensor): # result, scale return at[1], at[2] - inputs = [ - empty_bf16(5, 4), # input - empty_bf16(4), # weight - ] - pm.register_replacement( pattern, replacement, - inputs, + self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass, ) @@ -272,16 +265,16 @@ def __init__(self, def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor): + def pattern(input: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor): result_rms, residual = self.rmsnorm_matcher( input, weight, residual) result, scale = self.quant_matcher(result_rms) return result, residual, scale - def replacement(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor): + def replacement(input: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=torch.float16) # TODO model dtype @@ -301,16 +294,10 @@ def replacement(input: torch.Tensor, residual: torch.Tensor, # result, residual, scale return at[1], at[3], at[2] - inputs = [ - empty_bf16(5, 4), # input - empty_bf16(5, 4), # residual - empty_bf16(4), # weight - ] - pm.register_replacement( pattern, replacement, - inputs, + self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass, ) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 9cde9230211f..a72e7396f526 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from abc import ABC, abstractmethod +from typing import Optional import torch from torch._higher_order_ops import auto_functionalized @@ -31,55 +32,71 @@ # kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 -class MatcherRMSNorm: # TODO separate residual and not residual +class MatcherCustomOp(ABC): - def __init__(self, epsilon: float, enabled: Optional[bool] = None): - self.epsilon = epsilon + def __init__(self, enabled: bool): + self.model_dtype = get_current_vllm_config().model_config.dtype + + self.enabled = enabled + self.forward = self.forward_custom if enabled else self.forward_native + + @abstractmethod + def forward_custom(self, *args, **kws): + pass + + @abstractmethod + def forward_native(self, *args, **kws): + pass + def __call__(self, *args, **kws): + return self.forward(*args, **kws) + + def empty(self, *args, **kws): + return torch.empty(*args, dtype=self.model_dtype, device="cuda", **kws) + + def empty_f32(self, *args, **kws): + return torch.empty(*args, dtype=torch.float32, device="cuda", **kws) + + +class MatcherRMSNorm(MatcherCustomOp): + + def __init__(self, epsilon: float, enabled: Optional[bool] = None): if enabled is None: # TODO either pass config to enabled or set it globally # (global during pass init seems reasonable) enabled = RMSNorm.enabled() - self.forward = self.forward_custom if enabled else self.forward_native - self.model_dtype = get_current_vllm_config().model_config.dtype - print(self.model_dtype) + super().__init__(enabled) + self.epsilon = epsilon def inputs(self): - return + input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) + weight = self.empty(16, ) + return [input, weight] def forward_custom( self, input: torch.Tensor, weight: torch.Tensor, residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - if residual is None: - result = torch.empty_like(input) - _, result = auto_functionalized( - RMS_OP, - result=result, - input=input, - weight=weight, - epsilon=self.epsilon, - ) - - return result - else: - _, result, residual = auto_functionalized(RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon) + ) -> torch.Tensor: + result = torch.empty_like(input) + _, result = auto_functionalized( + RMS_OP, + result=result, + input=input, + weight=weight, + epsilon=self.epsilon, + ) - return result, residual + return result def forward_native( self, input: torch.Tensor, weight: torch.Tensor, residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + ) -> torch.Tensor: x = input.to(torch.float32) if residual is not None: x = x + residual @@ -94,13 +111,57 @@ def forward_native( return x if residual is None else (x, residual) - def __call__( + +class MatcherFusedAddRMSNorm(MatcherCustomOp): + + def __init__(self, epsilon: float, enabled: Optional[bool] = None): + if enabled is None: + # TODO either pass config to enabled or set it globally + # (global during pass init seems reasonable) + enabled = RMSNorm.enabled() + + super().__init__(enabled) + self.epsilon = epsilon + + def inputs(self): + input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) + weight = self.empty(16, ) + residual = self.empty(5, 16) + return [input, weight, residual] + + def forward_custom( self, input: torch.Tensor, weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - return self.forward(input, weight, residual) + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + _, result, residual = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + + return result, residual + + def forward_native( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + x = input.to(torch.float32) + if residual is not None: + x = x + residual + residual = x.to(self.model_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + + x = x * torch.rsqrt(variance + self.epsilon) + x = x.to(self.model_dtype) + if weight is not None: + x = x * weight + + return x if residual is None else (x, residual) class MatcherQuant: From e6b394e28a10cd19e2db7af52686591bc30599a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 19 Sep 2025 19:00:27 -0700 Subject: [PATCH 008/137] Add TODO MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4a9a497989e8..edda51e2844a 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -64,7 +64,8 @@ def forward(self, x): input_scale=self.scale[0]) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - + # TODO another fp8 linear + rmsnorm to make sure fusion + # works for residual output as well x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1], From d96913a7987a6747eef7cfddc05c45c1220312b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 16:06:25 -0400 Subject: [PATCH 009/137] Cleanup test_fusion.py, added extra layer of rms/quant MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 52 +++++++++++------------------------- 1 file changed, 15 insertions(+), 37 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index edda51e2844a..3b494fce3bae 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,8 +5,7 @@ import torch import vllm.plugins -from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, RMS_OP, - FusedRMSQuantKey, RMSNormQuantFusionPass) +from vllm.compilation.fusion import RMSNormQuantFusionPass from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import (CompilationConfig, CompilationLevel, ModelConfig, @@ -30,18 +29,18 @@ def __init__(self, hidden_size: int, eps: float, static: bool, cuda_force_torch: bool, *args, **kwargs): super().__init__(*args, **kwargs) self.cuda_force_torch = cuda_force_torch - self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] - self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN quant_scale = ScaleDesc(torch.float32, static, group_shape) self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] else: - self.scale = [None for _ in range(2)] + self.scale = [None for _ in range(3)] self.w = [ torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - for _ in range(2) + for _ in range(3) ] with override_cutlass_fp8_supported(not cuda_force_torch): @@ -64,34 +63,21 @@ def forward(self, x): input_scale=self.scale[0]) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - # TODO another fp8 linear + rmsnorm to make sure fusion - # works for residual output as well + x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1], input_scale=self.scale[1]) - y3, resid = self.norm[2](x3, resid) # use resid here - return y3 - def ops_in_model_before(self): - ops = [] - if self.enable_rms_norm: - ops += [RMS_OP] - else: - ops += [torch.ops.aten.rsqrt.default] - - if self.enable_quant_fp8: - ops += [QUANT_OPS[self.key]] - else: - ops += [torch.ops.aten.reciprocal.default] + y3, resid = self.norm[2](x3, resid) # use resid here - return ops + x4 = self.fp8_linear.apply(y3, + self.w[2], + self.wscale[2], + input_scale=self.scale[2]) - def ops_in_model_after(self): - return [ - FUSED_OPS[FusedRMSQuantKey(self.key, False)], - FUSED_OPS[FusedRMSQuantKey(self.key, True)] - ] + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 @pytest.mark.parametrize("dtype", [torch.float16]) #, torch.bfloat16]) @@ -123,8 +109,6 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, vllm_config = VllmConfig( model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( - debug_dump_path=f"/home/luka/git/vllm/._workspace/" - f"debug_dump_{enable_rms_norm}_{enable_quant_fp8}", level=CompilationLevel.PIECEWISE, custom_ops=custom_ops, pass_config=PassConfig(enable_fusion=True, enable_noop=True), @@ -156,10 +140,4 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) - assert fusion_pass.matched_count == 2 - - # In pre-nodes, fp8 quant should be there and fused kernels should not - # backend.check_before_ops(model.ops_in_model_before()) - - # In post-nodes, fused kernels should be there and fp8 quant should not - # backend.check_after_ops(model.ops_in_model_after()) + assert fusion_pass.matched_count == 3 From b1727475027f129549da400e41e46bb4e4e045a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 15:02:33 -0700 Subject: [PATCH 010/137] Functionalize attn+quant patterns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/backend.py | 6 +- tests/compile/test_fusion.py | 77 ++++++--- vllm/compilation/fusion.py | 275 ++++++++++++++++-------------- vllm/compilation/fusion_attn.py | 54 ++++-- vllm/compilation/matcher_utils.py | 83 +++++---- 5 files changed, 281 insertions(+), 214 deletions(-) diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 113906af0203..fb92fd7b42a5 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -56,8 +56,10 @@ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.inductor_config["post_grad_custom_post_pass"] = self.post_pass if compile_config.debug_dump_path: - self.debug_dump_path = (Path(compile_config.debug_dump_path) / - f"rank_{vllm_config.parallel_config.rank}") + self.debug_dump_path = ( + Path(compile_config.debug_dump_path) + / f"rank_{vllm_config.parallel_config.rank}" + ) self.ctx = depyf.prepare_debug(str(self.debug_dump_path)) self.ctx.__enter__() else: diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 3b494fce3bae..13cffbe087c6 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -8,13 +8,24 @@ from vllm.compilation.fusion import RMSNormQuantFusionPass from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import (CompilationConfig, CompilationLevel, ModelConfig, - PassConfig, VllmConfig) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + ModelConfig, + PassConfig, + VllmConfig, +) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, QuantKey, ScaleDesc) + GroupShape, + QuantKey, + ScaleDesc, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity) + Fp8LinearOp, + cutlass_fp8_supported, + maybe_create_device_identity, +) from vllm.platforms import current_platform from ..utils import override_cutlass_fp8_supported @@ -24,9 +35,15 @@ class TestModel(torch.nn.Module): - - def __init__(self, hidden_size: int, eps: float, static: bool, - cuda_force_torch: bool, *args, **kwargs): + def __init__( + self, + hidden_size: int, + eps: float, + static: bool, + cuda_force_torch: bool, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] @@ -57,30 +74,27 @@ def forward(self, x): x = resid = torch.relu(x) y = self.norm[0](x) - x2 = self.fp8_linear.apply(y, - self.w[0], - self.wscale[0], - input_scale=self.scale[0]) + x2 = self.fp8_linear.apply( + y, self.w[0], self.wscale[0], input_scale=self.scale[0] + ) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear.apply(y2, - self.w[1], - self.wscale[1], - input_scale=self.scale[1]) + x3 = self.fp8_linear.apply( + y2, self.w[1], self.wscale[1], input_scale=self.scale[1] + ) y3, resid = self.norm[2](x3, resid) # use resid here - x4 = self.fp8_linear.apply(y3, - self.w[2], - self.wscale[2], - input_scale=self.scale[2]) + x4 = self.fp8_linear.apply( + y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + ) y4, resid = self.norm[3](x4, resid) # use resid here return y4 -@pytest.mark.parametrize("dtype", [torch.float16]) #, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) # , torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @@ -89,13 +103,22 @@ def forward(self, x): @pytest.mark.parametrize("enable_quant_fp8", [True, False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. -@pytest.mark.parametrize("cuda_force_torch", - [True, False] if cutlass_fp8_supported() else [True]) -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Only test on CUDA and ROCm") -def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, - enable_rms_norm, enable_quant_fp8, - cuda_force_torch): +@pytest.mark.parametrize( + "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] +) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" +) +def test_fusion_rmsnorm_quant( + dtype, + hidden_size, + num_tokens, + eps, + static, + enable_rms_norm, + enable_quant_fp8, + cuda_force_torch, +): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 4e1b569f77e0..742e5355d1cf 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -12,8 +12,15 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, - kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) + GroupShape, + QuantKey, + ScaleDesc, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kNvfp4Quant, + kStaticTensorScale, +) from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode @@ -41,12 +48,9 @@ def empty_i32(*args, **kwargs): RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: - torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTensorSym: - torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTokenSym: - torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default @@ -58,77 +62,82 @@ class FusedRMSQuantKey(NamedTuple): quant: type of quantization fused_add: does the op also perform the residual add """ + quant: QuantKey fused_add: bool def __str__(self): - return (f"FusedQuantKey({self.quant}, with" - f"{'' if self.fused_add else 'out'} residual)") + return ( + f"FusedQuantKey({self.quant}, with" + f"{'' if self.fused_add else 'out'} residual)" + ) FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { - FusedRMSQuantKey(kFp8StaticTensorSym, False): - torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey(kFp8StaticTensorSym, True): - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey(kFp8DynamicTokenSym, False): - torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 - FusedRMSQuantKey(kFp8DynamicTokenSym, True): - torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8StaticTensorSym, False + ): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8StaticTensorSym, True + ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8DynamicTokenSym, False + ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8DynamicTokenSym, True + ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 } class RMSNormQuantPattern: - def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon self.quant_dtype = key.quant.dtype - assert key.quant in QUANT_OPS, \ - f"unsupported quantization scheme {key.quant}" + assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}" self.QUANT_OP = QUANT_OPS[key.quant] - assert key in FUSED_OPS, \ - f"unsupported fused rmsnorm+quant op for {key}" + assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] - self.rmsnorm_matcher = MatcherRMSNorm(epsilon) if not key.fused_add \ + self.rmsnorm_matcher = ( + MatcherRMSNorm(epsilon) + if not key.fused_add else MatcherFusedAddRMSNorm(epsilon) + ) self.quant_matcher = MatcherQuant(key.quant) class RMSNormStaticQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - symmetric=True): - fused_key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey(dtype=quant_dtype, - scale=kStaticTensorScale, - symmetric=symmetric)) + def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): + fused_key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey( + dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric + ), + ) super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing - def pattern(input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): + def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): result_rms = self.rmsnorm_matcher(input, weight) return self.quant_matcher(result_rms, scale)[0] - def replacement(input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=torch.float16) # TODO model dtype result = torch.empty_like(input, dtype=self.quant_dtype) - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + ) # result return at[1] @@ -136,51 +145,56 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, inputs = [ # input, weight *self.rmsnorm_matcher.inputs(), - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] pattern(*inputs) - pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, - pm_pass) + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - symmetric=True): - key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey(dtype=quant_dtype, - scale=kStaticTensorScale, - symmetric=symmetric)) + def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey( + dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric + ), + ) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - - def pattern(input: torch.Tensor, weight: torch.Tensor, - residual: torch.Tensor, scale: torch.Tensor): - result_rms, residual = self.rmsnorm_matcher( - input, weight, residual) + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + scale: torch.Tensor, + ): + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) result, _ = self.quant_matcher(result_rms, scale) return result, residual - def replacement(input: torch.Tensor, weight: torch.Tensor, - residual: torch.Tensor, scale: torch.Tensor): + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + scale: torch.Tensor, + ): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=torch.float16) # TODO model dtype residual = residual.to(dtype=torch.float16) result = torch.empty_like(input, dtype=self.quant_dtype) - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - residual=residual, - weight=weight, - scale=scale, - epsilon=self.epsilon) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon, + ) # result, residual return at[1], at[2] @@ -188,7 +202,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, inputs = [ # input, weight, residual *self.rmsnorm_matcher.inputs(), - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] pm.register_replacement( @@ -201,21 +215,21 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, class RMSNormDynamicQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey(dtype=quant_dtype, - scale=scale, - symmetric=symmetric)) + key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, weight: torch.Tensor): result_rms = self.rmsnorm_matcher(input, weight) # result, scale @@ -228,14 +242,16 @@ def replacement(input: torch.Tensor, weight: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=None) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None, + ) # result, scale return at[1], at[2] @@ -250,31 +266,30 @@ def replacement(input: torch.Tensor, weight: torch.Tensor): class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey(dtype=quant_dtype, - scale=scale, - symmetric=symmetric)) + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - - def pattern(input: torch.Tensor, weight: torch.Tensor, - residual: torch.Tensor): - result_rms, residual = self.rmsnorm_matcher( - input, weight, residual) + def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) result, scale = self.quant_matcher(result_rms) return result, residual, scale - def replacement(input: torch.Tensor, weight: torch.Tensor, - residual: torch.Tensor): + def replacement( + input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor + ): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=torch.float16) # TODO model dtype @@ -282,14 +297,16 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=residual) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual, + ) # result, residual, scale return at[1], at[3], at[2] @@ -314,25 +331,26 @@ def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="rmsnorm_quant_fusion_pass") + pass_name="rmsnorm_quant_fusion_pass" + ) with set_current_vllm_config(config, check_compile=False): for epsilon in [1e-5, 1e-6]: # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, - FP8_DTYPE).register(self.patterns) + RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns) + self.patterns + ) # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, - FP8_DTYPE).register(self.patterns) + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant - FusedAddRMSNormDynamicQuantPattern( - epsilon, FP8_DTYPE).register(self.patterns) + FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( + self.patterns + ) self.dump_patterns(config, self.patterns) @@ -342,8 +360,11 @@ def __call__(self, graph: fx.Graph): logger.debug("Replaced %s patterns", self.matched_count) def uuid(self) -> Any: - return self.hash_source(self, RMSNormQuantPattern, - RMSNormStaticQuantPattern, - RMSNormDynamicQuantPattern, - FusedAddRMSNormStaticQuantPattern, - FusedAddRMSNormDynamicQuantPattern) + return self.hash_source( + self, + RMSNormQuantPattern, + RMSNormStaticQuantPattern, + RMSNormDynamicQuantPattern, + FusedAddRMSNormStaticQuantPattern, + FusedAddRMSNormDynamicQuantPattern, + ) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index ae36cef92653..6933442552aa 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -2,9 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from typing import Callable import torch import torch._inductor.pattern_matcher as pm +from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass @@ -20,7 +22,9 @@ from vllm.utils import round_up from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 +from .fx_utils import is_func from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherQuant from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -66,9 +70,13 @@ def empty_quant(self, *args, **kwargs): return torch.empty(*args, **kwargs) @staticmethod - def wrap_trace_fn(process_fx, trace_fn): + def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]): def wrapped(*args, **kwargs): - return process_fx(trace_fn(*args, **kwargs)) + gm = trace_fn(*args, **kwargs) + for process_fx in process_fx_fns: + process_fx(gm) + + return gm return wrapped @@ -77,7 +85,20 @@ def fx_view_to_reshape(gm: torch.fx.GraphModule): from torch._inductor.fx_passes.post_grad import view_to_reshape view_to_reshape(gm) - return gm + + @staticmethod + def remove_noop_permutes(gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + if not is_func(node, torch.ops.aten.permute.default): + continue + + dims = node.args[1] + if any(dim != i for i, dim in enumerate(dims)): + continue + + # this is now an identity op, remove + node.replace_all_uses_with(node.args[0]) + gm.graph.erase_node(node) def register_if_supported(self, pm_pass: PatternMatcherPass): if self.layer.impl.fused_output_quant_supported(self.quant_key): @@ -108,6 +129,7 @@ def __init__( dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric ) super().__init__(layer, quant_key, dtype) + self.quant_matcher = MatcherQuant(quant_key) def _register(self, pm_pass: PatternMatcherPass): def pattern( @@ -115,7 +137,6 @@ def pattern( k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, - output_quant: torch.Tensor, scale: torch.Tensor, ): at1 = auto_functionalized( @@ -131,6 +152,11 @@ def pattern( attn_out_view = RESHAPE_OP( at1[1], [q.shape[0], self.num_heads * self.head_size] ) + output_quant = torch.empty( + attn_out_view.size(), + device=attn_out_view.device, + dtype=self.quant_dtype, + ) at2 = auto_functionalized( self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale ) @@ -141,7 +167,6 @@ def replacement( k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, - output_quant: torch.Tensor, scale: torch.Tensor, ): # attn output in quant_dtype @@ -164,13 +189,10 @@ def replacement( return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) inputs = [ - self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q - self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k - self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v - self.empty( - 5, self.num_heads, self.head_size, dtype=self.dtype - ), # attn_output - self.empty_quant(5, self.num_heads * self.head_size), # quant_output + self.empty(5, self.num_heads, self.head_size), # q + self.empty(5, self.num_heads, self.head_size), # k + self.empty(5, self.num_heads, self.head_size), # v + self.empty(5, self.num_heads, self.head_size), # attn_output empty_fp32(1, 1), # scale ] @@ -179,7 +201,9 @@ def replacement( replacement, inputs, AttentionQuantPattern.wrap_trace_fn( - AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only + pm.fwd_only, + AttentionQuantPattern.fx_view_to_reshape, + AttentionQuantPattern.remove_noop_permutes, ), pm_pass, ) @@ -279,7 +303,9 @@ def replacement( replacement, inputs, AttentionQuantPattern.wrap_trace_fn( - AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only + pm.fwd_only, + AttentionQuantPattern.fx_view_to_reshape, + AttentionQuantPattern.remove_noop_permutes, ), pm_pass, ) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index a72e7396f526..d3603372d69f 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -11,19 +11,20 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, _normalize_quant_group_shape, kFp8DynamicTensorSym, - kFp8DynamicTokenSym, kFp8StaticTensorSym) + QuantKey, + _normalize_quant_group_shape, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, +) RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: - torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTensorSym: - torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTokenSym: - torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } # TODO @@ -33,7 +34,6 @@ class MatcherCustomOp(ABC): - def __init__(self, enabled: bool): self.model_dtype = get_current_vllm_config().model_config.dtype @@ -59,7 +59,6 @@ def empty_f32(self, *args, **kws): class MatcherRMSNorm(MatcherCustomOp): - def __init__(self, epsilon: float, enabled: Optional[bool] = None): if enabled is None: # TODO either pass config to enabled or set it globally @@ -71,7 +70,9 @@ def __init__(self, epsilon: float, enabled: Optional[bool] = None): def inputs(self): input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) - weight = self.empty(16, ) + weight = self.empty( + 16, + ) return [input, weight] def forward_custom( @@ -113,7 +114,6 @@ def forward_native( class MatcherFusedAddRMSNorm(MatcherCustomOp): - def __init__(self, epsilon: float, enabled: Optional[bool] = None): if enabled is None: # TODO either pass config to enabled or set it globally @@ -125,7 +125,9 @@ def __init__(self, epsilon: float, enabled: Optional[bool] = None): def inputs(self): input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) - weight = self.empty(16, ) + weight = self.empty( + 16, + ) residual = self.empty(5, 16) return [input, weight, residual] @@ -135,11 +137,13 @@ def forward_custom( weight: torch.Tensor, residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - _, result, residual = auto_functionalized(RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon) + _, result, residual = auto_functionalized( + RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) return result, residual @@ -165,17 +169,13 @@ def forward_native( class MatcherQuant: - def __init__(self, quant_key: QuantKey, enabled: Optional[bool] = None): - self.quant_key = quant_key - assert quant_key in QUANT_OPS, \ - f"unsupported quantization scheme {quant_key}" + assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" self.QUANT_OP = QUANT_OPS[quant_key] assert quant_key.scale2 is None - self.quant_fp8 = QuantFP8(quant_key.scale.static, - quant_key.scale.group_shape) + self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape) if enabled is None: # TODO either pass config to enabled or set it globally @@ -191,25 +191,22 @@ def forward_custom( ) -> tuple[torch.Tensor, torch.Tensor]: # TODO: why does empty_like produce a permute but # empty via shape doesn't? - result = torch.empty(input.shape, - device=input.device, - dtype=self.quant_key.dtype) + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_key.dtype + ) if self.quant_key.scale.static: assert scale is not None - _, result = auto_functionalized(self.QUANT_OP, - result=result, - input=input, - scale=scale) + _, result = auto_functionalized( + self.QUANT_OP, result=result, input=input, scale=scale + ) return result, scale else: assert scale is None scale = self.make_scale(input) - _, result, scale = auto_functionalized(self.QUANT_OP, - result=result, - input=input, - scale=scale, - scale_ub=None) + _, result, scale = auto_functionalized( + self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None + ) return result, scale def forward_native( @@ -221,18 +218,16 @@ def forward_native( def make_scale(self, input: torch.Tensor): normalized_group_shape = _normalize_quant_group_shape( - input, self.quant_key.scale.group_shape) + input, self.quant_key.scale.group_shape + ) scale_shape = ( input.shape[0] // normalized_group_shape[0], input.shape[1] // normalized_group_shape[1], ) - return torch.empty(scale_shape, - device=input.device, - dtype=torch.float32) + return torch.empty(scale_shape, device=input.device, dtype=torch.float32) - def __call__(self, - input: torch.Tensor, - scale: Optional[torch.Tensor] = None - ) -> tuple[torch.Tensor, torch.Tensor]: + def __call__( + self, input: torch.Tensor, scale: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: return self.forward(input, scale) From 1ae80c6fff346994a199a358a6a89821e3890ce0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 16:02:21 -0700 Subject: [PATCH 011/137] Move global vllm_config to pass manager MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/fusion.py | 29 ++++++++++++------------- vllm/compilation/pass_manager.py | 37 +++++++++++++++++--------------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 742e5355d1cf..883743b635a8 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -9,7 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload -from vllm.config import VllmConfig, set_current_vllm_config +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -334,23 +334,22 @@ def __init__(self, config: VllmConfig): pass_name="rmsnorm_quant_fusion_pass" ) - with set_current_vllm_config(config, check_compile=False): - for epsilon in [1e-5, 1e-6]: - # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + for epsilon in [1e-5, 1e-6]: + # Fuse rms_norm + static fp8 quant + RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) - # Fuse fused_add_rms_norm + static fp8 quant - FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns - ) + # Fuse fused_add_rms_norm + static fp8 quant + FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( + self.patterns + ) - # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + # Fuse rms_norm + dynamic per-token fp8 quant + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) - # Fuse fused_add_rms_norm + dynamic per-token fp8 quant - FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns - ) + # Fuse fused_add_rms_norm + dynamic per-token fp8 quant + FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( + self.patterns + ) self.dump_patterns(config, self.patterns) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index e323fa1f7734..3d7c6287fe07 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -5,7 +5,7 @@ from torch import fx as fx from vllm import envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import set_env_var @@ -86,27 +86,30 @@ def __call__(self, graph: fx.Graph): def configure(self, config: VllmConfig): self.pass_config = config.compilation_config.pass_config - if self.pass_config.enable_noop: - self.passes += [NoOpEliminationPass(config)] - if self.pass_config.enable_sequence_parallelism: - self.passes += [SequenceParallelismPass(config)] - if self.pass_config.enable_async_tp: - self.passes += [AsyncTPPass(config)] + # Set the current vllm config to allow tracing CustomOp instances + with set_current_vllm_config(config, check_compile=False): + if self.pass_config.enable_noop: + self.passes += [NoOpEliminationPass(config)] - if self.pass_config.enable_fi_allreduce_fusion: - self.passes += [AllReduceFusionPass(config)] + if self.pass_config.enable_sequence_parallelism: + self.passes += [SequenceParallelismPass(config)] + if self.pass_config.enable_async_tp: + self.passes += [AsyncTPPass(config)] - if self.pass_config.enable_fusion: - self.passes += [RMSNormQuantFusionPass(config)] - self.passes += [ActivationQuantFusionPass(config)] + if self.pass_config.enable_fi_allreduce_fusion: + self.passes += [AllReduceFusionPass(config)] - if self.pass_config.enable_attn_fusion: - self.passes += [AttnFusionPass(config)] + if self.pass_config.enable_fusion: + self.passes += [RMSNormQuantFusionPass(config)] + self.passes += [ActivationQuantFusionPass(config)] - # needs a functional graph - self.post_cleanup = PostCleanupPass(config) - self.fix_functionalization = FixFunctionalizationPass(config) + if self.pass_config.enable_attn_fusion: + self.passes += [AttnFusionPass(config)] + + # needs a functional graph + self.post_cleanup = PostCleanupPass(config) + self.fix_functionalization = FixFunctionalizationPass(config) def add(self, pass_: InductorPass): assert isinstance(pass_, InductorPass) From 77835fd36531b2b88f591182dee4e61e9cd9639e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 16:12:11 -0700 Subject: [PATCH 012/137] Attention fusion works with custom ops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 80 ++++++++++++++++++------------- vllm/compilation/fusion_attn.py | 11 +---- 2 files changed, 48 insertions(+), 43 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 0f2e3bffbd31..5b6b7dcfe8f1 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -12,7 +12,6 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.registry import _Backend from vllm.attention.selector import global_force_attn_backend_context_manager -from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass @@ -242,26 +241,49 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ) +MODELS_FP8 = [] +MODELS_FP4 = [] +HEADS = [] +SPLIT_ATTENTION = [] +BACKENDS: list[_Backend] = [] + if current_platform.is_cuda(): - MODELS = [ + MODELS_FP8 = [ ( "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", TestAttentionFp8StaticQuantPatternModel, - ), - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", - TestAttentionNvfp4QuantPatternModel, - ), + ) ] HEADS = [(64, 8), (40, 8)] + SPLIT_ATTENTION = [False] + BACKENDS = [] # TODO [_Backend.TRITON_ATTN] + + if current_platform.is_device_capability((10, 0)): + BACKENDS += [_Backend.FLASHINFER] + MODELS_FP4 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel, + ) + ] + elif current_platform.is_rocm(): - MODELS = [ + MODELS_FP8 = [ ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ] HEADS = [(32, 8), (40, 8)] + SPLIT_ATTENTION = [False, True] + BACKENDS = [ + _Backend.TRITON_ATTN, + _Backend.ROCM_AITER_UNIFIED_ATTN, + _Backend.ROCM_ATTN, + ] + +# TODO(boyuan/luka): test inductor graph partition on rocm +if is_torch_equal_or_newer("2.9.0.dev") and current_platform.is_cuda(): + USE_INDUCTOR_GRAPH_PARTITION = [False, True] else: - MODELS = [] - HEADS = [] + USE_INDUCTOR_GRAPH_PARTITION = [False] @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @@ -270,35 +292,26 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8] ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("model_name, model_class", MODELS) @pytest.mark.parametrize( - "backend", - [_Backend.FLASHINFER] - if current_platform.is_cuda() - else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN], -) -# TODO(boyuan): test inductor graph partition on rocm -@pytest.mark.parametrize( - "use_inductor_graph_partition", - [False] if current_platform.is_rocm() else [False, True], + "model_name, model_class, custom_ops", + # Test attention+quant_fp8 fusion with custom and torch impls + [(*model, c) for model in MODELS_FP8 for c in ["+quant_fp8", "-quant_fp8"]] + # quant_fp4 only has the custom impl + + [(*model, c) for model in MODELS_FP4 for c in [""]], ) +@pytest.mark.parametrize("backend", BACKENDS) +@pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" ) @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif( - current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)), - reason="On CUDA only test on SM100(Blackwell)", -) -@pytest.mark.skipif( - not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" -) def test_attention_quant_pattern( num_qo_heads: int, num_kv_heads: int, head_size: int, batch_size: int, dtype: torch.dtype, + custom_ops: str, model_name: str, model_class: type[AttentionQuantPatternModel], backend: _Backend, @@ -308,8 +321,7 @@ def test_attention_quant_pattern( ): """Test AttentionStaticQuantPattern fusion pass""" - if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): - pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + custom_ops_list = custom_ops.split(",") if custom_ops else [] device = torch.device("cuda:0") torch.manual_seed(42) @@ -323,7 +335,7 @@ def test_attention_quant_pattern( scheduler_config=SchedulerConfig(max_num_seqs=1024), compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, - custom_ops=["+quant_fp8"], + custom_ops=custom_ops_list, use_inductor_graph_partition=use_inductor_graph_partition, ), cache_config=CacheConfig(cache_dtype="fp8"), @@ -420,12 +432,12 @@ def test_attention_quant_pattern( layer.impl.fused_output_quant_supported(quant_key) for key, layer in vllm_config.compilation_config.static_forward_context.items() ] - if any(attn_fusion_supported): - # Check quantization ops in the graph before and after fusion - test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True) + assert sum(attn_fusion_supported) == len(attn_fusion_supported), ( + "All layers should support attention fusion" + ) # access the underlying `AttnFusionPass` on the `LazyInitPass` - assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) + assert attn_pass.pass_.matched_count == 1 # Check attention ops in the graph before and after fusion attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 6933442552aa..761acb35834b 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -152,15 +152,8 @@ def pattern( attn_out_view = RESHAPE_OP( at1[1], [q.shape[0], self.num_heads * self.head_size] ) - output_quant = torch.empty( - attn_out_view.size(), - device=attn_out_view.device, - dtype=self.quant_dtype, - ) - at2 = auto_functionalized( - self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale - ) - return at2[1] + + return self.quant_matcher(attn_out_view, scale)[0] def replacement( q: torch.Tensor, From 1277999c297cf1fcc784ed3a1f698284e9f63cf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 16:12:23 -0700 Subject: [PATCH 013/137] Remove V0 attn fusion test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 5b6b7dcfe8f1..c91e162c8e74 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy -from typing import Optional import pytest import torch._dynamo @@ -39,10 +38,6 @@ FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 -# globals needed for string-import custom Dynamo backend field -backend: Optional[TestBackend] = None -backend_unfused: Optional[TestBackend] = None - class AttentionQuantPatternModel(torch.nn.Module): """Base model for AttentionQuantPattern fusion.""" From d843a67c428ae6e1c4397d24377a206b86cbb6cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 17:02:14 -0700 Subject: [PATCH 014/137] Add triton attn test to attn+quant fusion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index c91e162c8e74..4d6cdabf6a90 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +import itertools import pytest import torch._dynamo @@ -99,6 +100,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: num_blocks = batch_size * max_blocks backend = self.attn.backend + # TODO use get_kv_cache_stride_order # Create dummy KV cache for the selected backend if backend == _Backend.ROCM_ATTN: # k/v as 1st dimention @@ -240,7 +242,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): MODELS_FP4 = [] HEADS = [] SPLIT_ATTENTION = [] -BACKENDS: list[_Backend] = [] +BACKENDS_FP8: list[_Backend] = [] +BACKENDS_FP4: list[_Backend] = [] if current_platform.is_cuda(): MODELS_FP8 = [ @@ -251,10 +254,11 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ] HEADS = [(64, 8), (40, 8)] SPLIT_ATTENTION = [False] - BACKENDS = [] # TODO [_Backend.TRITON_ATTN] + BACKENDS_FP8 = [_Backend.TRITON_ATTN] if current_platform.is_device_capability((10, 0)): - BACKENDS += [_Backend.FLASHINFER] + BACKENDS_FP8 += [_Backend.FLASHINFER] + BACKENDS_FP4 += [_Backend.FLASHINFER] MODELS_FP4 += [ ( "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", @@ -288,13 +292,12 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize( - "model_name, model_class, custom_ops", + "backend, model, custom_ops", # Test attention+quant_fp8 fusion with custom and torch impls - [(*model, c) for model in MODELS_FP8 for c in ["+quant_fp8", "-quant_fp8"]] + list(itertools.product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"])) # quant_fp4 only has the custom impl - + [(*model, c) for model in MODELS_FP4 for c in [""]], + + list(itertools.product(BACKENDS_FP4, MODELS_FP4, [""])), ) -@pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" @@ -307,8 +310,7 @@ def test_attention_quant_pattern( batch_size: int, dtype: torch.dtype, custom_ops: str, - model_name: str, - model_class: type[AttentionQuantPatternModel], + model: tuple[str, type[AttentionQuantPatternModel]], backend: _Backend, use_inductor_graph_partition: bool, dist_init, @@ -317,6 +319,7 @@ def test_attention_quant_pattern( """Test AttentionStaticQuantPattern fusion pass""" custom_ops_list = custom_ops.split(",") if custom_ops else [] + model_name, model_class = model device = torch.device("cuda:0") torch.manual_seed(42) From cdd1529b0899cb18495ce46b8b27a0e5a0db5719 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 17:18:43 -0700 Subject: [PATCH 015/137] Flat product for better test names/visibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 4d6cdabf6a90..7d672bc343b4 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy import itertools +from collections.abc import Iterable +from typing import Any import pytest import torch._dynamo @@ -285,6 +287,13 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): USE_INDUCTOR_GRAPH_PARTITION = [False] +def flat_product(*iterables: Iterable[Any]): + """Flatten lists of tuples into cartesian product.""" + for element in itertools.product(*iterables): + normalized = (e if isinstance(e, tuple) else [e] for e in element) + yield list(itertools.chain(*normalized)) + + @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize( @@ -292,11 +301,11 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize( - "backend, model, custom_ops", - # Test attention+quant_fp8 fusion with custom and torch impls - list(itertools.product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"])) + "backend, model_name, model_class, custom_ops", + # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 + list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"])) # quant_fp4 only has the custom impl - + list(itertools.product(BACKENDS_FP4, MODELS_FP4, [""])), + + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])), ) @pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION) @pytest.mark.skipif( @@ -310,7 +319,8 @@ def test_attention_quant_pattern( batch_size: int, dtype: torch.dtype, custom_ops: str, - model: tuple[str, type[AttentionQuantPatternModel]], + model_name: str, + model_class: type[AttentionQuantPatternModel], backend: _Backend, use_inductor_graph_partition: bool, dist_init, @@ -319,7 +329,6 @@ def test_attention_quant_pattern( """Test AttentionStaticQuantPattern fusion pass""" custom_ops_list = custom_ops.split(",") if custom_ops else [] - model_name, model_class = model device = torch.device("cuda:0") torch.manual_seed(42) From 141a37eb431da104f4173ea7d1c0c3895354020d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 26 Sep 2025 07:41:41 -0700 Subject: [PATCH 016/137] Fix rmsnorm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/model_executor/layers/layernorm.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 976b2e852265..7e15efab379b 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -170,11 +170,9 @@ def __init__( self.variance_size_override = ( None if var_hidden_size == hidden_size else var_hidden_size ) - self.weight = None - if has_weight: - dtype = dtype or torch.get_default_dtype() - self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype)) - weight_dtype = self.weight.data.dtype + weight_dtype = dtype or torch.get_default_dtype() + self.has_weight = has_weight + self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) if current_platform.is_rocm(): self.rocm_norm_func = dispatch_rocm_rmsnorm_func( @@ -233,11 +231,12 @@ def forward_native( residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" + return self.forward_static( x, self.variance_epsilon, self.hidden_size, - self.weight.data, + self.weight.data if self.has_weight else None, residual, self.variance_size_override, ) From c6d6c3ba7f35105ed5809ac553012db7c1677746 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 26 Sep 2025 13:20:52 -0700 Subject: [PATCH 017/137] Refactor E2E attn fusion test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 89 ++++++++++++++++++++------------ 1 file changed, 55 insertions(+), 34 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 8ccae4cfb9df..dffa221a9f7f 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -24,23 +24,30 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ ("facebook/opt-125m", {}), - ( - "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", - { - "dtype": torch.float16, - }, - ), ( "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", { "dtype": torch.float16, }, ), - ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}), ] if all: + if not current_platform.has_device_capability((10, 0)): + # int8 removed on Blackwell + TEST_MODELS.extend( + [ + ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), + ( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + { + "dtype": torch.float16, + }, + ), + ] + ) + # TODO: figure out why this fails. if False and is_quant_method_supported("gguf"): # noqa: SIM223 TEST_MODELS.append( @@ -85,15 +92,14 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): "optimization_level", [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE], ) -@pytest.mark.parametrize("model_info", models_list(all=True)) +@pytest.mark.parametrize("model, model_kwargs", models_list(all=True)) @create_new_process_for_each_test() def test_full_graph( monkeypatch: pytest.MonkeyPatch, - model_info: tuple[str, dict[str, Any]], + model: str, + model_kwargs: dict[str, Any], optimization_level: int, ): - model, model_kwargs = model_info - with monkeypatch.context(): print(f"MODEL={model}") @@ -180,40 +186,55 @@ def test_fp8_kv_scale_compile(optimization_level: int): run_model(optimization_level, model, model_kwargs) -def test_inductor_graph_partition_attn_fusion(caplog_vllm): - if not is_torch_equal_or_newer("2.9.0.dev"): - pytest.skip("inductor graph partition is only available in PyTorch 2.9+") +INDUCTOR_GRAPH_PARTITION = ( + [False, True] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] +) + +@pytest.mark.parametrize("custom_ops", ["+quant_fp8", "-quant_fp8"]) +@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +def test_default_fusion( + custom_ops: str, inductor_graph_partition: bool, caplog_vllm, monkeypatch +): model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" + model_kwargs = {"kv_cache_dtype": "fp8", "max_model_len": 1024} + backend = _Backend.FLASHINFER + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: Optional[list[str]] = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + # Log capture also doesn't work with multiprocessing yet. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + compilation_config = CompilationConfig( + # Testing properties + custom_ops=custom_ops_list, + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + splitting_ops=splitting_ops, + # Common level=CompilationLevel.PIECEWISE, - use_inductor_graph_partition=True, - cudagraph_mode=CUDAGraphMode.PIECEWISE, - custom_ops=["+quant_fp8"], pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, ) - model_kwargs = { - "kv_cache_dtype": "fp8", - "max_model_len": 1024, - } + with ( caplog_vllm.at_level(logging.DEBUG), - global_force_attn_backend_context_manager(_Backend.FLASHINFER), + global_force_attn_backend_context_manager(backend), ): run_model(compilation_config, model, model_kwargs) - try: - assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, ( - caplog_vllm.text - ) - except AssertionError: - # Note: this message is only triggered when the compilation goes - # through the custom pass. Due to multiple layers of cache on - # PyTorch side, the compilation of a graph may be cached such - # that custom pass directly goes through cache. In this case, - # we go through this branch and assert that the pass is not - # triggered. - assert "Fused quantization" not in caplog_vllm.text + assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text def run_model( From 490ac8610d9e8876dd79619b9b8f72332e5694d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 26 Sep 2025 13:24:01 -0700 Subject: [PATCH 018/137] Add TP=2 test (untested) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 57 +++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index dffa221a9f7f..99b072cfd30f 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -18,7 +18,7 @@ from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer -from ..utils import create_new_process_for_each_test +from ..utils import create_new_process_for_each_test, multi_gpu_test def models_list(*, all: bool = True, keywords: list[str] | None = None): @@ -237,6 +237,61 @@ def test_default_fusion( assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("custom_ops", ["+quant_fp8", "-quant_fp8"]) +@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +def test_default_fusion_tp2( + custom_ops: str, inductor_graph_partition: bool, caplog_vllm, monkeypatch +): + model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" + model_kwargs = {"kv_cache_dtype": "fp8", "max_model_len": 1024} + backend = _Backend.FLASHINFER + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: Optional[list[str]] = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + # Log capture also doesn't work with multiprocessing yet. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + model_kwargs["tensor_parallel_size"] = 2 + compilation_config = CompilationConfig( + # Testing properties + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + custom_ops=custom_ops_list, + splitting_ops=splitting_ops, + # Common + level=CompilationLevel.PIECEWISE, + pass_config=PassConfig( + enable_attn_fusion=True, + enable_noop=True, + enable_fi_allreduce_fusion=True, + ), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with ( + caplog_vllm.at_level(logging.DEBUG), + global_force_attn_backend_context_manager(backend), + ): + run_model(compilation_config, model, model_kwargs) + + assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text + + # TODO fill in correct number + assert "Replaced 5 patterns" in caplog_vllm.text, caplog_vllm.text + + def run_model( compile_config: Union[int, CompilationConfig], model: str, From d0b1b563b4118afe73b3dfb7431359ef82de1830 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 26 Sep 2025 15:39:08 -0700 Subject: [PATCH 019/137] improve tests by adding more cases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 151 ++++++++++++++++++++++++------ tests/compile/test_fusion_attn.py | 11 +-- tests/utils.py | 9 ++ 3 files changed, 131 insertions(+), 40 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 99b072cfd30f..2f18488424e7 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -3,9 +3,11 @@ from __future__ import annotations +import itertools import logging import tempfile -from typing import Any, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import pytest import torch @@ -18,7 +20,7 @@ from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer -from ..utils import create_new_process_for_each_test, multi_gpu_test +from ..utils import create_new_process_for_each_test, flat_product, multi_gpu_test def models_list(*, all: bool = True, keywords: list[str] | None = None): @@ -103,7 +105,7 @@ def test_full_graph( with monkeypatch.context(): print(f"MODEL={model}") - run_model(optimization_level, model, model_kwargs) + run_model(optimization_level, model, **model_kwargs) # TODO(luka) add other supported compilation config scenarios here @@ -168,7 +170,49 @@ def test_custom_compile_config( model, model_kwargs = model_info print(f"MODEL={model}") - run_model(compilation_config, model, model_kwargs) + run_model(compilation_config, model, **model_kwargs) + + +MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = [] +MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = [] +MODELS: list[tuple[str, dict[str, Any], _Backend]] = [] # tp-only + +if current_platform.is_cuda(): + MODELS_FP8 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + {"max_model_len": 1024}, + _Backend.TRITON_ATTN, + ) + ] + + if current_platform.is_device_capability((10, 0)): + MODELS_FP8 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + {"kv_cache_dtype": "fp8", "max_model_len": 1024}, + _Backend.FLASHINFER, + ) + ] + + MODELS_FP4 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + {"kv_cache_dtype": "fp8", "max_model_len": 1024}, + _Backend.FLASHINFER, + ) + ] + + MODELS += [ + ( + "meta-llama/Llama-3.1-8B-Instruct", + {"max_model_len": 1024}, + _Backend.FLASHINFER, + ) + ] + +elif current_platform.is_rocm(): + MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)] @pytest.mark.parametrize( @@ -183,23 +227,34 @@ def test_fp8_kv_scale_compile(optimization_level: int): "calculate_kv_scales": True, "max_model_len": 512, } - run_model(optimization_level, model, model_kwargs) + run_model(optimization_level, model, **model_kwargs) INDUCTOR_GRAPH_PARTITION = ( [False, True] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] ) +# TODO(luka) test both in nightly +CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] + -@pytest.mark.parametrize("custom_ops", ["+quant_fp8", "-quant_fp8"]) +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, custom_ops", + # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 + list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) + # quant_fp4 only has the custom impl + + list(flat_product(MODELS_FP4, [""])), +) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) -def test_default_fusion( - custom_ops: str, inductor_graph_partition: bool, caplog_vllm, monkeypatch +def test_e2e_fusion_attn_quant( + model_name: str, + model_kwargs: dict[str, Any], + backend: _Backend, + custom_ops: str, + inductor_graph_partition: bool, + caplog_vllm, + monkeypatch, ): - model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" - model_kwargs = {"kv_cache_dtype": "fp8", "max_model_len": 1024} - backend = _Backend.FLASHINFER - custom_ops_list = custom_ops.split(",") if custom_ops else [] if inductor_graph_partition: @@ -232,21 +287,47 @@ def test_default_fusion( caplog_vllm.at_level(logging.DEBUG), global_force_attn_backend_context_manager(backend), ): - run_model(compilation_config, model, model_kwargs) + run_model(compilation_config, model_name, **model_kwargs) assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text +# TODO(luka) test both in nightly +CUSTOM_OPS_RMS_NORM = ["-rms_norm"] # , "+rms_norm"] + + +def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: + for op_list in itertools.product(*custom_ops_lists): + yield ",".join(op_list) + + @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("custom_ops", ["+quant_fp8", "-quant_fp8"]) +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, custom_ops", + # Toggle RMSNorm and QuantFP8 for FP8 models + list( + flat_product( + MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) + ) + ) + # Toggle RMSNorm for FP4 models and unquant models + + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), +) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) -def test_default_fusion_tp2( - custom_ops: str, inductor_graph_partition: bool, caplog_vllm, monkeypatch +@pytest.mark.skipif( + not current_platform.is_cuda() + or not current_platform.has_device_capability((10, 0)), + reason="allreduce+rmsnorm fusion only supported on blackwell", +) +def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( + model_name, + model_kwargs, + backend, + custom_ops: str, + inductor_graph_partition: bool, + caplog_vllm, + monkeypatch, ): - model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" - model_kwargs = {"kv_cache_dtype": "fp8", "max_model_len": 1024} - backend = _Backend.FLASHINFER - custom_ops_list = custom_ops.split(",") if custom_ops else [] if inductor_graph_partition: @@ -262,7 +343,6 @@ def test_default_fusion_tp2( monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - model_kwargs["tensor_parallel_size"] = 2 compilation_config = CompilationConfig( # Testing properties use_inductor_graph_partition=inductor_graph_partition, @@ -284,19 +364,25 @@ def test_default_fusion_tp2( caplog_vllm.at_level(logging.DEBUG), global_force_attn_backend_context_manager(backend), ): - run_model(compilation_config, model, model_kwargs) + run_model( + compilation_config, model_name, tensor_parallel_size=2, **model_kwargs + ) assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text # TODO fill in correct number - assert "Replaced 5 patterns" in caplog_vllm.text, caplog_vllm.text + assert "Replaced 96 patterns" in caplog_vllm.text, caplog_vllm.text def run_model( - compile_config: Union[int, CompilationConfig], - model: str, - model_kwargs: dict[str, Any], + compile_config: Union[int, CompilationConfig], model: str, **model_kwargs ): + compilation_config = ( + compile_config + if isinstance(compile_config, CompilationConfig) + else CompilationConfig(level=compile_config) + ) + prompts = [ "Hello, my name is", "The president of the United States is", @@ -304,12 +390,17 @@ def run_model( "The future of AI is", ] sampling_params = SamplingParams(temperature=0) + # Allow override from model_kwargs + model_kwargs = {"tensor_parallel_size": 1, **model_kwargs} + model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs} + + # No cudagraphs by default + if compilation_config.cudagraph_mode is None: + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + llm = LLM( model=model, - enforce_eager=True, - tensor_parallel_size=1, - disable_custom_all_reduce=True, - compilation_config=compile_config, + compilation_config=compilation_config, **model_kwargs, ) outputs = llm.generate(prompts, sampling_params) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 7d672bc343b4..b52b573ec7e5 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy -import itertools -from collections.abc import Iterable -from typing import Any import pytest import torch._dynamo from tests.compile.backend import LazyInitPass, TestBackend +from tests.utils import flat_product from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention import Attention, AttentionMetadata @@ -287,13 +285,6 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): USE_INDUCTOR_GRAPH_PARTITION = [False] -def flat_product(*iterables: Iterable[Any]): - """Flatten lists of tuples into cartesian product.""" - for element in itertools.product(*iterables): - normalized = (e if isinstance(e, tuple) else [e] for e in element) - yield list(itertools.chain(*normalized)) - - @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize( diff --git a/tests/utils.py b/tests/utils.py index b853542c241f..16ef6458cf50 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,6 +6,7 @@ import copy import functools import importlib +import itertools import json import os import random @@ -15,6 +16,7 @@ import tempfile import time import warnings +from collections.abc import Iterable from contextlib import ExitStack, contextmanager, suppress from multiprocessing import Process from pathlib import Path @@ -1260,3 +1262,10 @@ def check_answers( frac_ok = numok / len(answer) print(f"Num OK: {numok}/{len(answer)} {frac_ok}") assert frac_ok >= accept_rate + + +def flat_product(*iterables: Iterable[Any]): + """Flatten lists of tuples into cartesian product.""" + for element in itertools.product(*iterables): + normalized = (e if isinstance(e, tuple) else [e] for e in element) + yield list(itertools.chain(*normalized)) From 47b4688d1cdad8701e2aad381fd2568fc1bce78e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sat, 27 Sep 2025 07:38:52 -0700 Subject: [PATCH 020/137] TEMP working on caplog MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 8 ++++++-- tests/conftest.py | 22 ++++++++++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 2f18488424e7..b282e234572f 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -253,6 +253,7 @@ def test_e2e_fusion_attn_quant( custom_ops: str, inductor_graph_partition: bool, caplog_vllm, + caplog_mp_workaround, monkeypatch, ): custom_ops_list = custom_ops.split(",") if custom_ops else [] @@ -268,7 +269,7 @@ def test_e2e_fusion_attn_quant( # Otherwise, we can't verify fusion happened through the logs. # Log capture also doesn't work with multiprocessing yet. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") compilation_config = CompilationConfig( # Testing properties @@ -285,6 +286,7 @@ def test_e2e_fusion_attn_quant( with ( caplog_vllm.at_level(logging.DEBUG), + caplog_mp_workaround(), global_force_attn_backend_context_manager(backend), ): run_model(compilation_config, model_name, **model_kwargs) @@ -319,6 +321,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: or not current_platform.has_device_capability((10, 0)), reason="allreduce+rmsnorm fusion only supported on blackwell", ) +@pytest.mark.skip(reason="Still no solution for capturing logs from subprocess") def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( model_name, model_kwargs, @@ -341,7 +344,8 @@ def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( # Otherwise, we can't verify fusion happened through the logs. # Log capture also doesn't work with multiprocessing yet. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + # TODO + # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") compilation_config = CompilationConfig( # Testing properties diff --git a/tests/conftest.py b/tests/conftest.py index 4713e1238596..b2fa96f48e8c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# ruff: noqa +import contextlib from tblib import pickling_support +# ruff: noqa + # Install support for pickling exceptions so that we can nicely propagate # failures from tests running in a subprocess. # This should be run before any custom exception subclasses are defined. @@ -1067,6 +1068,23 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): yield caplog +@pytest.fixture() +def caplog_mp_workaround(): + @contextlib.contextmanager + def ctx(): + import logging.handlers + import multiprocessing as mp + + logger_queue: mp.Queue[logging.LogRecord] = mp.Queue() + logger = logging.getLogger() + logger.addHandler(logging.handlers.QueueHandler(logger_queue)) + yield + while not logger_queue.empty(): + logger.handle(logger_queue.get()) + + return ctx + + @pytest.fixture(scope="session") def num_gpus_available(): """Get number of GPUs without initializing the CUDA context From ae7f56f042876122127b04e199162b985334b747 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 30 Sep 2025 12:50:28 -0700 Subject: [PATCH 021/137] Temp MP workaround P2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/conftest.py | 89 +++++++++++++++++++++++++++++++++++++++++--- tests/test_logger.py | 17 +++++++++ 2 files changed, 101 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b2fa96f48e8c..bbfe3eeac8f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import pathlib +from copy import deepcopy from tblib import pickling_support @@ -41,7 +43,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs -from vllm import LLM, SamplingParams +from vllm import LLM, SamplingParams, envs from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset @@ -1068,8 +1070,25 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): yield caplog -@pytest.fixture() -def caplog_mp_workaround(): +@pytest.fixture(scope="session") +def caplog_mp_fork(): + """ + This fixture enables capturing logs from a forked MP subprocess. + It should be used in conjunction with caplog_vllm. + + By default, subprocess logs do not go through the parent process. + We instead create a queue listener in the parent process which + forwards logs to the logger's other handlers, and add a QueueHandler + to the root logger. Forked subprocesses will inherit the root logger + and pass their messages to the queue, which the listener will forward + to the root logger, which can be captured by caplog. + + Note that this workaround only works for fork; with spawn, the subprocess + reinitializes logging and does not automatically inherit the queue. + We'd have to manually pass the queue to the subprocess at the spawn point. + See caplog_mp_spawn below. + """ + @contextlib.contextmanager def ctx(): import logging.handlers @@ -1077,10 +1096,70 @@ def ctx(): logger_queue: mp.Queue[logging.LogRecord] = mp.Queue() logger = logging.getLogger() + handlers = logger.handlers + + # The listener works on a background thread, not inherited by the child. + queue_listener = logging.handlers.QueueListener(logger_queue, *handlers) + queue_listener.start() + + # Add queue handler after creating the listener to avoid cycle logger.addHandler(logging.handlers.QueueHandler(logger_queue)) yield - while not logger_queue.empty(): - logger.handle(logger_queue.get()) + queue_listener.stop() + + return ctx + + +class LogHolder: + def __init__(self): + self.text = None + + +@pytest.fixture(scope="session") +def caplog_mp_spawn(tmp_path, monkeypatch): + """ + This fixture enables capturing logs from a forked MP subprocess. + It does not require caplog_vllm (but it only contains log + + By default, subprocess logs do not go through the parent process. + We instead add a FileHandler to the config so the spawned child process + writes its logs to a temp file and then return the contents. + + Note: this method could be extended to fork by either reconfiguring logging + in the parent or using a SocketHandler: + https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network # noqa: E501 + """ + + @contextlib.contextmanager + def ctx(level: int | str): + from vllm.logger import DEFAULT_LOGGING_CONFIG + + config_path = tmp_path / "vllm_logging_config.json" + log_path = tmp_path / "vllm.log" + log_holder = LogHolder() + + config = deepcopy(DEFAULT_LOGGING_CONFIG) + if envs.VLLM_LOGGING_CONFIG_PATH: + path = pathlib.Path(envs.VLLM_LOGGING_CONFIG_PATH) + assert path.exists() + config = json.loads(path.read_text()) + + config["loggers"]["vllm"]["handlers"] += ["vllm_file"] + config["handlers"]["vllm_file"] = { + "class": "logging.FileHandler", + "formatter": "vllm", + "level": level, + "filename": log_path.as_posix(), + } + + config_path.write_text(json.dumps(config)) + + with monkeypatch.context() as monkeypatch_ctx: + monkeypatch_ctx.setenv("VLLM_LOGGING_CONFIG_PATH", config_path.as_posix()) + monkeypatch_ctx.setenv("VLLM_CONFIGURE_LOGGING", "1") + yield log_holder + + log_holder.text = log_path.read_text() return ctx diff --git a/tests/test_logger.py b/tests/test_logger.py index ec368d4897b5..af006f1456b8 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -501,3 +501,20 @@ def test_streaming_complete_logs_full_text_content(): assert call_args[1] == "test-streaming-full-text" assert call_args[2] == " (streaming complete)" assert call_args[5] == "streaming_complete" + + +test_logger = init_logger("vllm.test_logger") +# https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network + + +def mp_function(**kwargs): + # This function runs in a subprocess + + test_logger.warning("This is a subprocess: %s", kwargs.get("a")) + test_logger.error("This is a subprocess error.") + test_logger.debug("This is a subprocess debug message: %s.", kwargs.get("b")) + + +def test_caplog_mp_fork(caplog_vllm, caplog_mp_fork): + pass + # TODO From eb899a4d34d2b7ca6d77b36decbd46b3eb873e13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 30 Sep 2025 12:55:33 -0700 Subject: [PATCH 022/137] Temp MP workaround P3 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/conftest.py | 3 ++- tests/test_logger.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bbfe3eeac8f7..df34924a6f70 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1123,7 +1123,8 @@ def caplog_mp_spawn(tmp_path, monkeypatch): By default, subprocess logs do not go through the parent process. We instead add a FileHandler to the config so the spawned child process - writes its logs to a temp file and then return the contents. + writes its logs to a temp file. + In the parent, we read the file and return the contents. Note: this method could be extended to fork by either reconfiguring logging in the parent or using a SocketHandler: diff --git a/tests/test_logger.py b/tests/test_logger.py index af006f1456b8..22e084991343 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -504,7 +504,6 @@ def test_streaming_complete_logs_full_text_content(): test_logger = init_logger("vllm.test_logger") -# https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network def mp_function(**kwargs): From a2aa9787df6ed5be6fd9f6010e0e330702052036 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 1 Oct 2025 11:21:02 -0700 Subject: [PATCH 023/137] Test for caplog utils MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/conftest.py | 4 ++-- tests/test_logger.py | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index df34924a6f70..7a907b2ac79f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1070,7 +1070,7 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): yield caplog -@pytest.fixture(scope="session") +@pytest.fixture() def caplog_mp_fork(): """ This fixture enables capturing logs from a forked MP subprocess. @@ -1115,7 +1115,7 @@ def __init__(self): self.text = None -@pytest.fixture(scope="session") +@pytest.fixture() def caplog_mp_spawn(tmp_path, monkeypatch): """ This fixture enables capturing logs from a forked MP subprocess. diff --git a/tests/test_logger.py b/tests/test_logger.py index 22e084991343..f1c31c245475 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -515,5 +515,34 @@ def mp_function(**kwargs): def test_caplog_mp_fork(caplog_vllm, caplog_mp_fork): - pass - # TODO + with caplog_vllm.at_level(logging.DEBUG), caplog_mp_fork(): + import multiprocessing + + ctx = multiprocessing.get_context("fork") + p = ctx.Process( + target=mp_function, + name=f"SubProcess{1}", + kwargs={"a": "AAAA", "b": "BBBBB"}, + ) + p.start() + p.join() + + assert "AAAA" in caplog_vllm.text + assert "BBBBB" in caplog_vllm.text + + +def test_caplog_mp_spawn(caplog_mp_spawn): + with caplog_mp_spawn(logging.DEBUG) as log_holder: + import multiprocessing + + ctx = multiprocessing.get_context("spawn") + p = ctx.Process( + target=mp_function, + name=f"SubProcess{1}", + kwargs={"a": "AAAA", "b": "BBBBB"}, + ) + p.start() + p.join() + + assert "AAAA" in log_holder.text + assert "BBBBB" in log_holder.text From 21a9f9f42b21f46190838a0336438fe5c091e728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 1 Oct 2025 19:02:24 -0700 Subject: [PATCH 024/137] Fixed tests, passing with 2.8, 2.9 tbd MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 48 +++++++++++++++++--------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index b282e234572f..b6f7aba6821c 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -10,12 +10,12 @@ from typing import Any, Optional, Union import pytest +import regex as re import torch from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams from vllm.attention.backends.registry import _Backend -from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer @@ -235,7 +235,8 @@ def test_fp8_kv_scale_compile(optimization_level: int): ) # TODO(luka) test both in nightly -CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] +# TODO(luka) change to - +CUSTOM_OPS_FP8 = ["+quant_fp8"] # , "+quant_fp8"] @pytest.mark.parametrize( @@ -252,8 +253,7 @@ def test_e2e_fusion_attn_quant( backend: _Backend, custom_ops: str, inductor_graph_partition: bool, - caplog_vllm, - caplog_mp_workaround, + caplog_mp_spawn, monkeypatch, ): custom_ops_list = custom_ops.split(",") if custom_ops else [] @@ -269,7 +269,11 @@ def test_e2e_fusion_attn_quant( # Otherwise, we can't verify fusion happened through the logs. # Log capture also doesn't work with multiprocessing yet. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) compilation_config = CompilationConfig( # Testing properties @@ -284,18 +288,15 @@ def test_e2e_fusion_attn_quant( inductor_compile_config={"force_disable_caches": True}, ) - with ( - caplog_vllm.at_level(logging.DEBUG), - caplog_mp_workaround(), - global_force_attn_backend_context_manager(backend), - ): + with caplog_mp_spawn(logging.DEBUG) as log_holder: run_model(compilation_config, model_name, **model_kwargs) - assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text + assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text # TODO(luka) test both in nightly -CUSTOM_OPS_RMS_NORM = ["-rms_norm"] # , "+rms_norm"] +# TODO(luka) change to - +CUSTOM_OPS_RMS_NORM = ["+rms_norm"] # , "+rms_norm"] def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @@ -321,14 +322,13 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: or not current_platform.has_device_capability((10, 0)), reason="allreduce+rmsnorm fusion only supported on blackwell", ) -@pytest.mark.skip(reason="Still no solution for capturing logs from subprocess") def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( model_name, model_kwargs, backend, custom_ops: str, inductor_graph_partition: bool, - caplog_vllm, + caplog_mp_spawn, monkeypatch, ): custom_ops_list = custom_ops.split(",") if custom_ops else [] @@ -344,8 +344,11 @@ def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( # Otherwise, we can't verify fusion happened through the logs. # Log capture also doesn't work with multiprocessing yet. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - # TODO - # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) compilation_config = CompilationConfig( # Testing properties @@ -364,18 +367,17 @@ def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( inductor_compile_config={"force_disable_caches": True}, ) - with ( - caplog_vllm.at_level(logging.DEBUG), - global_force_attn_backend_context_manager(backend), - ): + with caplog_mp_spawn(logging.DEBUG) as log_holder: run_model( compilation_config, model_name, tensor_parallel_size=2, **model_kwargs ) - assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text + assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text - # TODO fill in correct number - assert "Replaced 96 patterns" in caplog_vllm.text, caplog_vllm.text + matches = re.findall( + r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text + ) + assert len(matches) == 2, log_holder.text def run_model( From 66a35a90b724f53037395bc96fcef87ea8b3b172 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 19:26:42 -0400 Subject: [PATCH 025/137] Update tests/compile/backend.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/backend.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/compile/backend.py b/tests/compile/backend.py index fb92fd7b42a5..ac62040287d2 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -4,7 +4,6 @@ import weakref from collections.abc import Sequence from copy import deepcopy -from pathlib import Path from typing import Callable, Union import depyf @@ -55,12 +54,8 @@ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.inductor_config["force_disable_caches"] = True self.inductor_config["post_grad_custom_post_pass"] = self.post_pass - if compile_config.debug_dump_path: - self.debug_dump_path = ( - Path(compile_config.debug_dump_path) - / f"rank_{vllm_config.parallel_config.rank}" - ) - self.ctx = depyf.prepare_debug(str(self.debug_dump_path)) + if debug_dump_path := vllm_config.compile_debug_dump_path(): + self.ctx = depyf.prepare_debug(debug_dump_path.as_posix()) self.ctx.__enter__() else: self.ctx = None From 7eb1364457d3ebc65d0a156a4a48ccef122bba5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 19:26:48 -0400 Subject: [PATCH 026/137] Update csrc/layernorm_kernels.cu MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- csrc/layernorm_kernels.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index b738cdbbdc53..b037531cceb5 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -380,6 +380,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { + TORCH_CHECK(weight.scalar_type() == input.scalar_type()); TORCH_CHECK(input.scalar_type() == residual.scalar_type()); TORCH_CHECK(residual.is_contiguous()); TORCH_CHECK(weight.is_contiguous()); From 5fef1804edebe5d0bb441f0e7620f691164707d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 16:35:31 -0700 Subject: [PATCH 027/137] clean up fullgraph tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 43 ++++++++++++++------------------ 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index b6f7aba6821c..d5d22844a223 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -173,6 +173,21 @@ def test_custom_compile_config( run_model(compilation_config, model, **model_kwargs) +@pytest.mark.parametrize( + "optimization_level", + [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE], +) +def test_fp8_kv_scale_compile(optimization_level: int): + model = "Qwen/Qwen2-0.5B" + model_kwargs = { + "quantization": "fp8", + "kv_cache_dtype": "fp8_e4m3", + "calculate_kv_scales": True, + "max_model_len": 512, + } + run_model(optimization_level, model, **model_kwargs) + + MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = [] MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = [] MODELS: list[tuple[str, dict[str, Any], _Backend]] = [] # tp-only @@ -214,29 +229,12 @@ def test_custom_compile_config( elif current_platform.is_rocm(): MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)] - -@pytest.mark.parametrize( - "optimization_level", - [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE], -) -def test_fp8_kv_scale_compile(optimization_level: int): - model = "Qwen/Qwen2-0.5B" - model_kwargs = { - "quantization": "fp8", - "kv_cache_dtype": "fp8_e4m3", - "calculate_kv_scales": True, - "max_model_len": 512, - } - run_model(optimization_level, model, **model_kwargs) - - INDUCTOR_GRAPH_PARTITION = ( - [False, True] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] + [True, False] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] ) # TODO(luka) test both in nightly -# TODO(luka) change to - -CUSTOM_OPS_FP8 = ["+quant_fp8"] # , "+quant_fp8"] +CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] @pytest.mark.parametrize( @@ -308,11 +306,8 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @pytest.mark.parametrize( "model_name, model_kwargs, backend, custom_ops", # Toggle RMSNorm and QuantFP8 for FP8 models - list( - flat_product( - MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) - ) - ) + list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"])) + # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO # Toggle RMSNorm for FP4 models and unquant models + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), ) From db479ae069e833e0c48186bbbb40ef7173c4485d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 16:51:30 -0700 Subject: [PATCH 028/137] TEMP allreduce fusion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_all_reduce.py | 37 +++---- vllm/compilation/collective_fusion.py | 129 +++++++++--------------- 2 files changed, 66 insertions(+), 100 deletions(-) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 7e5c460db174..7d63a380d72c 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -17,6 +17,7 @@ ModelConfig, PassConfig, VllmConfig, + set_current_vllm_config, ) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -233,24 +234,26 @@ def all_reduce_fusion_pass_on_test_model( vllm_config.model_config = ModelConfig( model=model_name, trust_remote_code=True, dtype=dtype, seed=42 ) + with set_current_vllm_config(vllm_config): + all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) + noop_pass = NoOpEliminationPass(vllm_config) + func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + backend = TestBackend( + all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass + ) - all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) - noop_pass = NoOpEliminationPass(vllm_config) - func_pass = FixFunctionalizationPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) - - backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass) - - token_num = batch_size * seq_len - model = test_model_cls(hidden_size, token_num) + token_num = batch_size * seq_len + model = test_model_cls(hidden_size, token_num) - hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) - residual = torch.randn((token_num, hidden_size), requires_grad=False) + hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) + residual = torch.randn((token_num, hidden_size), requires_grad=False) - compiled_model = torch.compile(model, backend=backend) - compiled_model(hidden_states, residual) + compiled_model = torch.compile(model, backend=backend) + compiled_model(hidden_states, residual) - assert all_reduce_fusion_pass.matched_count == 1 - backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) - backend.check_after_ops(model.ops_in_model_after()) - del all_reduce_fusion_pass + assert all_reduce_fusion_pass.matched_count == 1 + backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) + backend.check_after_ops(model.ops_in_model_after()) + del all_reduce_fusion_pass diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 988a1069cd9e..b41655ffd130 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -18,10 +18,14 @@ get_tensor_model_parallel_world_size, ) from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuant, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass FP8_DTYPE = current_platform.fp8_dtype() @@ -646,6 +650,19 @@ def get_trtllm_fused_allreduce_kwargs(self): } +class BaseAllReduceRMSNormPattern(BasePattern): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + class AllReduceRMSNormPattern(BasePattern): """ This pattern replaces the allreduce + rms norm (without residual) @@ -663,33 +680,24 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def get_inputs(self): input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - rms_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) weight = torch.empty([4], device=self.device, dtype=self.dtype) - return [input, rms_result, weight] + return [input, weight] def register(self, pm_pass: PatternMatcherPass): - def pattern( - input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor - ): + def pattern(input: torch.Tensor, weight: torch.Tensor): allreduce_output = tensor_model_parallel_all_reduce(input) - rms = auto_functionalized( - RMS_OP, - result=rms_result, - input=allreduce_output, - weight=weight, - epsilon=self.epsilon, - ) - # rms_result, allreduce_output - return rms[1], allreduce_output + rms = self.rmsnorm_matcher(allreduce_output, weight) - def replacement( - input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor - ): + return rms, allreduce_output + + def replacement(input: torch.Tensor, weight: torch.Tensor): residual = torch.zeros_like(input) + rms_result = torch.empty_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -727,6 +735,7 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) def get_inputs(self): input = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -741,15 +750,8 @@ def get_inputs(self): def register(self, pm_pass: PatternMatcherPass): def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): allreduce_output = tensor_model_parallel_all_reduce(input) - rms = auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - # input, residual - return rms[1], rms[2] + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) + return allreduce_output, residual def replacement( residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor @@ -793,60 +795,36 @@ def __init__( self.epsilon = epsilon self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + self.quant_matcher = MatcherQuant(kFp8StaticTensorSym) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) - rmsnorm_result = torch.empty( - [1, 8, 4], device=self.device, dtype=self.dtype - ) - quant_result = torch.empty( - [1, 8, 4], device=self.device, dtype=self.quant_dtype - ) weight = torch.empty([4], device=self.device, dtype=self.dtype) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - return [input, rmsnorm_result, quant_result, weight, scale] + return [input, weight, scale] def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized( - RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon, - ) - - quant_out_tuple = auto_functionalized( - STATIC_FP8_QUANT_OP, - result=quant_result, - input=rmsnorm_out_tuple[1], - scale=scale, - ) - - # quant_out, allreduce_output - return quant_out_tuple[1], all_reduce + rms = self.rmsnorm_matcher(all_reduce, weight) + quant, _ = self.quant_matcher(rms, scale) + return quant, all_reduce - def replacement( - input: torch.Tensor, - result_rms: torch.Tensor, - quant_result: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): residual = torch.zeros_like(input) + result_rms = torch.empty_like(input) + result_quant = torch.empty_like(input, dtype=self.quant_dtype) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, norm_out=result_rms, - quant_out=quant_result, + quant_out=result_quant, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, @@ -886,19 +864,18 @@ def __init__( self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + self.quant_matcher = MatcherQuant(kFp8StaticTensorSym) + def register(self, pm_pass: PatternMatcherPass): def get_inputs(): input = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - quant_result = torch.empty( - [4, 4], device=self.device, dtype=self.quant_dtype - ) scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) return [ - quant_result, residual, input, weight, @@ -906,44 +883,30 @@ def get_inputs(): ] def pattern( - quant_result: torch.Tensor, residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): allreduce_output = tensor_model_parallel_all_reduce(input) + rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual) + quant, _ = self.quant_matcher(rms, scale) - fused_add_rmsnorm_out_tuple = auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - quant_out_tuple = auto_functionalized( - STATIC_FP8_QUANT_OP, - result=quant_result, - input=fused_add_rmsnorm_out_tuple[1], - scale=scale, - ) - - # quant_out, allreduce_output - return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2] + return quant, allreduce_output def replacement( - quant_result: torch.Tensor, residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): + result_quant = torch.empty_like(input, dtype=self.quant_dtype) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, norm_out=None, - quant_out=quant_result, + quant_out=result_quant, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, From 54189a9f880335a73c4d5ec8601027c7aeb43bc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 21:24:51 -0400 Subject: [PATCH 029/137] allreduce fusion working (custom ops on) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_all_reduce.py | 8 +++----- vllm/compilation/collective_fusion.py | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 7d63a380d72c..88305c0ed85c 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -84,16 +84,13 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): self.norm = RMSNorm(hidden_size, eps) self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) self.scale = torch.rand(1, dtype=torch.float32) - self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) def forward(self, hidden_states, residual): view = hidden_states.reshape(-1, self.hidden_size) all_reduce = tensor_model_parallel_all_reduce(view) norm_output, residual_output = self.norm(all_reduce, residual) - torch.ops._C.static_scaled_fp8_quant( - self.output, norm_output.contiguous(), self.scale - ) - return self.output, residual_output + quant_out, _ = self.quant_fp8(norm_output, self.scale) + return quant_out, residual_output def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -227,6 +224,7 @@ def all_reduce_fusion_pass_on_test_model( enable_fi_allreduce_fusion=True, enable_noop=True ) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + vllm_config.parallel_config.rank = local_rank # Setup rank for debug path # this is a fake model name to construct the model config # in the vllm_config, it's not really used. diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index b41655ffd130..7d212ef17fb4 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -751,7 +751,7 @@ def register(self, pm_pass: PatternMatcherPass): def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): allreduce_output = tensor_model_parallel_all_reduce(input) rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) - return allreduce_output, residual + return rms, residual def replacement( residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor @@ -892,7 +892,7 @@ def pattern( rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual) quant, _ = self.quant_matcher(rms, scale) - return quant, allreduce_output + return quant, res def replacement( residual: torch.Tensor, From b7f52bf2fe31b1215b0dc3c81f801944671989f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 22:12:04 -0400 Subject: [PATCH 030/137] allreduce fusion working with/without custom ops (except fp4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_all_reduce.py | 46 +++++++++++++++++++------ 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 88305c0ed85c..12fa56826840 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -66,8 +66,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def forward(self, hidden_states, residual): view = hidden_states.reshape(-1, self.hidden_size) all_reduce = tensor_model_parallel_all_reduce(view) - norm, _ = self.norm(all_reduce, residual) - return norm + norm, res = self.norm(all_reduce, residual) + + return norm, res def ops_in_model_before(self): return [torch.ops.vllm.all_reduce.default] @@ -98,7 +99,9 @@ def ops_in_model_after(self): def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, - torch.ops._C.static_scaled_fp8_quant.default, + torch.ops._C.static_scaled_fp8_quant.default + if self.quant_fp8.enabled() + else torch.ops.aten.reciprocal.default, ] @@ -139,19 +142,21 @@ def ops_in_model_before(self): @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "test_model", + "test_model, enable_quant_fp8", [ - TestAllReduceRMSNormModel, - TestAllReduceFusedAddRMSNormModel, - TestAllReduceFusedAddRMSNormStaticQuantFP8Model, + (TestAllReduceRMSNormModel, False), + (TestAllReduceFusedAddRMSNormModel, False), + (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, True), + (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, False), # TODO: Enable with torch==2.8.0 - # TestAllReduceFusedAddRMSNormStaticQuantFP4Model, + # (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), ], ) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("enable_rms_norm", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( not find_spec("flashinfer") @@ -165,6 +170,8 @@ def test_all_reduce_fusion_pass_replace( seq_len: int, hidden_size: int, dtype: torch.dtype, + enable_rms_norm, + enable_quant_fp8, ): num_processes = 2 if ( @@ -179,7 +186,16 @@ def test_all_reduce_fusion_pass_replace( def run_torch_spawn(fn, nprocs): torch.multiprocessing.spawn( fn, - args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype), + args=( + num_processes, + test_model, + batch_size, + seq_len, + hidden_size, + dtype, + enable_rms_norm, + enable_quant_fp8, + ), nprocs=nprocs, ) @@ -194,6 +210,8 @@ def all_reduce_fusion_pass_on_test_model( seq_len: int, hidden_size: int, dtype: torch.dtype, + enable_rms_norm, + enable_quant_fp8, ): current_platform.seed_everything(0) @@ -215,9 +233,15 @@ def all_reduce_fusion_pass_on_test_model( init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) + custom_ops = [] + if enable_rms_norm: + custom_ops.append("+rms_norm") + if enable_quant_fp8: + custom_ops.append("+quant_fp8") + vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"] + level=CompilationLevel.PIECEWISE, custom_ops=custom_ops ) ) vllm_config.compilation_config.pass_config = PassConfig( @@ -239,7 +263,7 @@ def all_reduce_fusion_pass_on_test_model( cleanup_pass = PostCleanupPass(vllm_config) backend = TestBackend( - all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass + noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass ) token_num = batch_size * seq_len From d09a278fa869c8e625403cca3e578f6806ca5623 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 22:16:24 -0400 Subject: [PATCH 031/137] allreduce fusion working with/without custom ops (with fp4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_all_reduce.py | 3 +-- vllm/compilation/collective_fusion.py | 30 ++++++------------------- 2 files changed, 8 insertions(+), 25 deletions(-) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 12fa56826840..657ebc4a28a6 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -148,8 +148,7 @@ def ops_in_model_before(self): (TestAllReduceFusedAddRMSNormModel, False), (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, True), (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, False), - # TODO: Enable with torch==2.8.0 - # (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), + (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), ], ) @pytest.mark.parametrize("batch_size", [8]) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7d212ef17fb4..d5a3fcde03b6 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -943,6 +943,7 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): @@ -976,18 +977,11 @@ def pattern( output_scale: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized( - RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon, - ) - + rms = self.rmsnorm_matcher(all_reduce, weight) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, output=quant_result, - input=rmsnorm_out_tuple[1], + input=rms, output_scale=output_scale, input_scale=input_global_scale, ) @@ -1047,6 +1041,7 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): @@ -1078,28 +1073,17 @@ def pattern( input_global_scale: torch.Tensor, ): allreduce_output = tensor_model_parallel_all_reduce(input) - - fused_add_rmsnorm_out_tuple = auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, output=quant_result, - input=fused_add_rmsnorm_out_tuple[1], + input=rms, output_scale=output_scale, input_scale=input_global_scale, ) # quant_out, allreduce_output, output_scale - return ( - quant_out_tuple[1], - fused_add_rmsnorm_out_tuple[2], - quant_out_tuple[2], - ) + return quant_out_tuple[1], residual, quant_out_tuple[2] def replacement( quant_result: torch.Tensor, From c8675ffdbcf5167da76fae8ffa6d3ffdd0e30146 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 22:18:24 -0400 Subject: [PATCH 032/137] log depyf folder, fix context for TestBackend, fix pattern dump MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/backend.py | 18 +++++++++++------- vllm/compilation/monitor.py | 1 + vllm/compilation/vllm_inductor_pass.py | 3 ++- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/compile/backend.py b/tests/compile/backend.py index ac62040287d2..a16ab9f15c9f 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -3,6 +3,7 @@ import weakref from collections.abc import Sequence +from contextlib import nullcontext from copy import deepcopy from typing import Callable, Union @@ -16,6 +17,9 @@ from vllm.compilation.pass_manager import with_pattern_match_debug from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig, get_current_vllm_config +from vllm.logger import init_logger + +logger = init_logger("vllm.tests.compile.backend") class LazyInitPass(InductorPass): @@ -55,16 +59,19 @@ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.inductor_config["post_grad_custom_post_pass"] = self.post_pass if debug_dump_path := vllm_config.compile_debug_dump_path(): - self.ctx = depyf.prepare_debug(debug_dump_path.as_posix()) - self.ctx.__enter__() + logger.debug("Dumping depyf output to %s", debug_dump_path) + self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix()) else: - self.ctx = None + self.debug_ctx = nullcontext() def __call__(self, graph: fx.GraphModule, example_inputs): self.graph_pre_compile = deepcopy(graph) from torch._inductor.compile_fx import compile_fx - return compile_fx(graph, example_inputs, config_patches=self.inductor_config) + with self.debug_ctx: + return compile_fx( + graph, example_inputs, config_patches=self.inductor_config + ) @with_pattern_match_debug def post_pass(self, graph: fx.Graph): @@ -83,9 +90,6 @@ def post_pass(self, graph: fx.Graph): # assign by reference, will reflect the final state of the graph self.final_graph = graph - if self.ctx is not None: - self.ctx.__exit__(None, None, None) - def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True): for op in ops: num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index d3c437795fab..f9a189b7c77d 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -22,6 +22,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): import depyf path.mkdir(parents=True, exist_ok=True) + logger.debug("Dumping depyf output to %s", path) global context_manager context_manager = depyf.prepare_debug(path.as_posix()) context_manager.__enter__() diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 5aa08220bc2d..b7b3c98eb4ed 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -115,7 +115,8 @@ def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): f" please add to dump_patterns if there are any errors.\n\n" f"from torch._higher_order_ops.auto_functionalize import " f"auto_functionalized as auto_functionalized\n" - f"from torch._inductor.pattern_matcher import *", + f"from torch._inductor.pattern_matcher import *\n" + f"vllm = torch.ops.vllm", file=f, ) From d3f95feda3df0127fb40a33828bd880cee2c9c54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 3 Oct 2025 11:38:39 -0400 Subject: [PATCH 033/137] fullgraph allreduce test update requirements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index d5d22844a223..beaeed30f004 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -19,6 +19,7 @@ from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer +from vllm.utils.flashinfer import has_flashinfer from ..utils import create_new_process_for_each_test, flat_product, multi_gpu_test @@ -314,8 +315,9 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) @pytest.mark.skipif( not current_platform.is_cuda() - or not current_platform.has_device_capability((10, 0)), - reason="allreduce+rmsnorm fusion only supported on blackwell", + or not has_flashinfer() + or not current_platform.has_device_capability(90), + reason="allreduce+rmsnorm fusion requires flashinfer", ) def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( model_name, From 4dbfcf7017116e622f7365beb9fc6562076fd53d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 3 Oct 2025 11:49:24 -0400 Subject: [PATCH 034/137] Move e2e tests to new file, add to test pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 13 +- tests/compile/test_full_graph.py | 198 +----------------------- tests/compile/test_fusions_e2e.py | 246 ++++++++++++++++++++++++++++++ 3 files changed, 255 insertions(+), 202 deletions(-) create mode 100644 tests/compile/test_fusions_e2e.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ebe0602a1b5d..f734526db130 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -796,8 +796,8 @@ steps: # Whisper needs spawn method to avoid deadlock - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper -- label: Blackwell Test # 38 min - timeout_in_minutes: 60 +- label: Blackwell Test # 48 min + timeout_in_minutes: 70 working_dir: "/vllm-workspace/" gpu: b200 # optional: true @@ -810,8 +810,7 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - - vllm/compilation/fusion.py - - vllm/compilation/fusion_attn.py + - vllm/compilation/ commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -828,6 +827,8 @@ steps: - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py + - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py # Fusion @@ -835,8 +836,7 @@ steps: - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py - - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py - - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py + - pytest -v -s tests/compile/test_fusions_e2e.py - label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 @@ -1109,6 +1109,7 @@ steps: commands: - pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm ##### RL Integration Tests ##### - label: Prime-RL Integration Test # 15min diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index beaeed30f004..402e6499b9d6 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -3,25 +3,19 @@ from __future__ import annotations -import itertools -import logging import tempfile -from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any, Union import pytest -import regex as re import torch from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.attention.backends.registry import _Backend from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer -from vllm.utils.flashinfer import has_flashinfer -from ..utils import create_new_process_for_each_test, flat_product, multi_gpu_test +from ..utils import create_new_process_for_each_test def models_list(*, all: bool = True, keywords: list[str] | None = None): @@ -189,194 +183,6 @@ def test_fp8_kv_scale_compile(optimization_level: int): run_model(optimization_level, model, **model_kwargs) -MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = [] -MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = [] -MODELS: list[tuple[str, dict[str, Any], _Backend]] = [] # tp-only - -if current_platform.is_cuda(): - MODELS_FP8 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", - {"max_model_len": 1024}, - _Backend.TRITON_ATTN, - ) - ] - - if current_platform.is_device_capability((10, 0)): - MODELS_FP8 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", - {"kv_cache_dtype": "fp8", "max_model_len": 1024}, - _Backend.FLASHINFER, - ) - ] - - MODELS_FP4 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", - {"kv_cache_dtype": "fp8", "max_model_len": 1024}, - _Backend.FLASHINFER, - ) - ] - - MODELS += [ - ( - "meta-llama/Llama-3.1-8B-Instruct", - {"max_model_len": 1024}, - _Backend.FLASHINFER, - ) - ] - -elif current_platform.is_rocm(): - MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)] - -INDUCTOR_GRAPH_PARTITION = ( - [True, False] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] -) - -# TODO(luka) test both in nightly -CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] - - -@pytest.mark.parametrize( - "model_name, model_kwargs, backend, custom_ops", - # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 - list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) - # quant_fp4 only has the custom impl - + list(flat_product(MODELS_FP4, [""])), -) -@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) -def test_e2e_fusion_attn_quant( - model_name: str, - model_kwargs: dict[str, Any], - backend: _Backend, - custom_ops: str, - inductor_graph_partition: bool, - caplog_mp_spawn, - monkeypatch, -): - custom_ops_list = custom_ops.split(",") if custom_ops else [] - - if inductor_graph_partition: - mode = CUDAGraphMode.FULL_AND_PIECEWISE - splitting_ops: Optional[list[str]] = None - else: - mode = CUDAGraphMode.FULL_DECODE_ONLY - splitting_ops = [] - - # Disable, compile cache to make sure custom passes run. - # Otherwise, we can't verify fusion happened through the logs. - # Log capture also doesn't work with multiprocessing yet. - monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - - # To capture subprocess logs, we need to know whether spawn or fork is used. - # Force spawn as it is more general. - monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) - - compilation_config = CompilationConfig( - # Testing properties - custom_ops=custom_ops_list, - use_inductor_graph_partition=inductor_graph_partition, - cudagraph_mode=mode, - splitting_ops=splitting_ops, - # Common - level=CompilationLevel.PIECEWISE, - pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), - # Inductor caches custom passes by default as well via uuid - inductor_compile_config={"force_disable_caches": True}, - ) - - with caplog_mp_spawn(logging.DEBUG) as log_holder: - run_model(compilation_config, model_name, **model_kwargs) - - assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text - - -# TODO(luka) test both in nightly -# TODO(luka) change to - -CUSTOM_OPS_RMS_NORM = ["+rms_norm"] # , "+rms_norm"] - - -def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: - for op_list in itertools.product(*custom_ops_lists): - yield ",".join(op_list) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize( - "model_name, model_kwargs, backend, custom_ops", - # Toggle RMSNorm and QuantFP8 for FP8 models - list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"])) - # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO - # Toggle RMSNorm for FP4 models and unquant models - + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), -) -@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) -@pytest.mark.skipif( - not current_platform.is_cuda() - or not has_flashinfer() - or not current_platform.has_device_capability(90), - reason="allreduce+rmsnorm fusion requires flashinfer", -) -def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( - model_name, - model_kwargs, - backend, - custom_ops: str, - inductor_graph_partition: bool, - caplog_mp_spawn, - monkeypatch, -): - custom_ops_list = custom_ops.split(",") if custom_ops else [] - - if inductor_graph_partition: - mode = CUDAGraphMode.FULL_AND_PIECEWISE - splitting_ops: Optional[list[str]] = None - else: - mode = CUDAGraphMode.FULL_DECODE_ONLY - splitting_ops = [] - - # Disable, compile cache to make sure custom passes run. - # Otherwise, we can't verify fusion happened through the logs. - # Log capture also doesn't work with multiprocessing yet. - monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - - # To capture subprocess logs, we need to know whether spawn or fork is used. - # Force spawn as it is more general. - monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) - - compilation_config = CompilationConfig( - # Testing properties - use_inductor_graph_partition=inductor_graph_partition, - cudagraph_mode=mode, - custom_ops=custom_ops_list, - splitting_ops=splitting_ops, - # Common - level=CompilationLevel.PIECEWISE, - pass_config=PassConfig( - enable_attn_fusion=True, - enable_noop=True, - enable_fi_allreduce_fusion=True, - ), - # Inductor caches custom passes by default as well via uuid - inductor_compile_config={"force_disable_caches": True}, - ) - - with caplog_mp_spawn(logging.DEBUG) as log_holder: - run_model( - compilation_config, model_name, tensor_parallel_size=2, **model_kwargs - ) - - assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text - - matches = re.findall( - r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text - ) - assert len(matches) == 2, log_holder.text - - def run_model( compile_config: Union[int, CompilationConfig], model: str, **model_kwargs ): diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py new file mode 100644 index 000000000000..b0700e4e86a4 --- /dev/null +++ b/tests/compile/test_fusions_e2e.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import itertools +import logging +from collections.abc import Iterable +from typing import Any, Optional, Union + +import pytest +import regex as re + +from tests.v1.attention.utils import _Backend +from vllm import LLM, SamplingParams +from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig +from vllm.platforms import current_platform +from vllm.utils import is_torch_equal_or_newer +from vllm.utils.flashinfer import has_flashinfer + +from ..utils import flat_product, multi_gpu_test + +MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = [] +MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = [] +MODELS: list[tuple[str, dict[str, Any], _Backend]] = [] # tp-only + +if current_platform.is_cuda(): + MODELS_FP8 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + {"max_model_len": 1024}, + _Backend.TRITON_ATTN, + ) + ] + + if current_platform.is_device_capability((10, 0)): + MODELS_FP8 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + {"kv_cache_dtype": "fp8", "max_model_len": 1024}, + _Backend.FLASHINFER, + ) + ] + + MODELS_FP4 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + {"kv_cache_dtype": "fp8", "max_model_len": 1024}, + _Backend.FLASHINFER, + ) + ] + + MODELS += [ + ( + "meta-llama/Llama-3.1-8B-Instruct", + {"max_model_len": 1024}, + _Backend.FLASHINFER, + ) + ] + +elif current_platform.is_rocm(): + MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)] + +INDUCTOR_GRAPH_PARTITION = ( + [True, False] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] +) + +# TODO(luka) test both in nightly +CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] + + +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, custom_ops", + # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 + list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) + # quant_fp4 only has the custom impl + + list(flat_product(MODELS_FP4, [""])), +) +@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +def test_attn_quant( + model_name: str, + model_kwargs: dict[str, Any], + backend: _Backend, + custom_ops: str, + inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: Optional[list[str]] = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + # Log capture also doesn't work with multiprocessing yet. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + custom_ops=custom_ops_list, + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + splitting_ops=splitting_ops, + # Common + level=CompilationLevel.PIECEWISE, + pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model(compilation_config, model_name, **model_kwargs) + + assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text + + +# TODO(luka) test both in nightly +# TODO(luka) change to - +CUSTOM_OPS_RMS_NORM = ["+rms_norm"] # , "+rms_norm"] + + +def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: + for op_list in itertools.product(*custom_ops_lists): + yield ",".join(op_list) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, custom_ops", + # Toggle RMSNorm and QuantFP8 for FP8 models + list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"])) + # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO + # Toggle RMSNorm for FP4 models and unquant models + + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), +) +@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.skipif( + not current_platform.is_cuda() + or not has_flashinfer() + or not current_platform.has_device_capability(90), + reason="allreduce+rmsnorm fusion requires flashinfer", +) +def test_tp2_attn_quant_allreduce_rmsnorm( + model_name, + model_kwargs, + backend, + custom_ops: str, + inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: Optional[list[str]] = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + # Log capture also doesn't work with multiprocessing yet. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + custom_ops=custom_ops_list, + splitting_ops=splitting_ops, + # Common + level=CompilationLevel.PIECEWISE, + pass_config=PassConfig( + enable_attn_fusion=True, + enable_noop=True, + enable_fi_allreduce_fusion=True, + ), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model( + compilation_config, model_name, tensor_parallel_size=2, **model_kwargs + ) + + assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text + + matches = re.findall( + r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text + ) + assert len(matches) == 2, log_holder.text + + +def run_model( + compile_config: Union[int, CompilationConfig], model: str, **model_kwargs +): + compilation_config = ( + compile_config + if isinstance(compile_config, CompilationConfig) + else CompilationConfig(level=compile_config) + ) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0) + # Allow override from model_kwargs + model_kwargs = {"tensor_parallel_size": 1, **model_kwargs} + model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs} + + # No cudagraphs by default + if compilation_config.cudagraph_mode is None: + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + llm = LLM( + model=model, + compilation_config=compilation_config, + **model_kwargs, + ) + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") From 31d0127c71e73b6ee257cf564833681da71cff1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 3 Oct 2025 13:01:13 -0400 Subject: [PATCH 035/137] Add e2e fusions to fullgraph test (should work with Triton backend), disable without flashinfer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 5 +++-- tests/compile/test_fusions_e2e.py | 4 +--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f734526db130..85616de5b197 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -416,8 +416,8 @@ steps: - pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/piecewise/ -- label: PyTorch Fullgraph Test # 20min - timeout_in_minutes: 30 +- label: PyTorch Fullgraph Test # 22min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -425,6 +425,7 @@ steps: - tests/compile commands: - pytest -v -s compile/test_full_graph.py + - pytest -v -s compile/test_fusions_e2e.py - label: Kernels Core Operation Test # 48min timeout_in_minutes: 75 diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index b0700e4e86a4..cbeaa8bcb3f3 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -33,7 +33,7 @@ ) ] - if current_platform.is_device_capability((10, 0)): + if current_platform.is_device_capability((10, 0)) and has_flashinfer(): MODELS_FP8 += [ ( "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", @@ -97,7 +97,6 @@ def test_attn_quant( # Disable, compile cache to make sure custom passes run. # Otherwise, we can't verify fusion happened through the logs. - # Log capture also doesn't work with multiprocessing yet. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") # To capture subprocess logs, we need to know whether spawn or fork is used. @@ -170,7 +169,6 @@ def test_tp2_attn_quant_allreduce_rmsnorm( # Disable, compile cache to make sure custom passes run. # Otherwise, we can't verify fusion happened through the logs. - # Log capture also doesn't work with multiprocessing yet. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") # To capture subprocess logs, we need to know whether spawn or fork is used. From c653d24a39d32f4e227fd845335b539366ea35e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 3 Oct 2025 23:39:23 -0400 Subject: [PATCH 036/137] Fix spelling, precommit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/conftest.py | 2 +- vllm/compilation/matcher_utils.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 7a907b2ac79f..76bbccc29534 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1119,7 +1119,7 @@ def __init__(self): def caplog_mp_spawn(tmp_path, monkeypatch): """ This fixture enables capturing logs from a forked MP subprocess. - It does not require caplog_vllm (but it only contains log + It does not require caplog_vllm (but it only contains logs from the child). By default, subprocess logs do not go through the parent process. We instead add a FileHandler to the config so the spawned child process diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index d3603372d69f..55fbeadc22fe 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -57,6 +57,10 @@ def empty(self, *args, **kws): def empty_f32(self, *args, **kws): return torch.empty(*args, dtype=torch.float32, device="cuda", **kws) + def inputs(self) -> list[torch.Tensor]: + """Utility for inputs to the pattern""" + raise NotImplementedError + class MatcherRMSNorm(MatcherCustomOp): def __init__(self, epsilon: float, enabled: Optional[bool] = None): From 1756f6755970f14c3cf643cb522c69208f80c3b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sat, 4 Oct 2025 00:06:13 -0400 Subject: [PATCH 037/137] add back fp4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/matcher_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 55fbeadc22fe..fe558b7acac2 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -16,7 +16,9 @@ kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, + kNvfp4Quant, ) +from vllm.platforms import current_platform RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default @@ -27,10 +29,8 @@ kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } -# TODO -# if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): -# QUANT_OPS[ -# kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 +if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 class MatcherCustomOp(ABC): From 5619bc38bc781cd70f8f3c12124fee3042ff7437 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 9 Oct 2025 21:42:34 -0400 Subject: [PATCH 038/137] clean up e2e tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusions_e2e.py | 147 ++++++++++++++++++++---------- 1 file changed, 99 insertions(+), 48 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index cbeaa8bcb3f3..6e4893cd0f66 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -10,6 +10,7 @@ import pytest import regex as re +from black.cache import NamedTuple from tests.v1.attention.utils import _Backend from vllm import LLM, SamplingParams @@ -20,72 +21,111 @@ from ..utils import flat_product, multi_gpu_test -MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = [] -MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = [] -MODELS: list[tuple[str, dict[str, Any], _Backend]] = [] # tp-only + +class ModelBackendTestCase(NamedTuple): + model_name: str + model_kwargs: dict[str, Any] + backend: _Backend + attention_fusions: int + allreduce_fusions: Optional[int] = None + + +MODELS_FP8: list[ModelBackendTestCase] = [] +MODELS_FP4: list[ModelBackendTestCase] = [] +MODELS: list[ModelBackendTestCase] = [] # tp-only if current_platform.is_cuda(): - MODELS_FP8 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", - {"max_model_len": 1024}, - _Backend.TRITON_ATTN, - ) + MODELS_FP8 = [ + ModelBackendTestCase( + model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=48, + allreduce_fusions=96, + ), + ModelBackendTestCase( + model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), + backend=_Backend.FLASHINFER, + attention_fusions=48, + allreduce_fusions=96, + ), ] - if current_platform.is_device_capability((10, 0)) and has_flashinfer(): - MODELS_FP8 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", - {"kv_cache_dtype": "fp8", "max_model_len": 1024}, - _Backend.FLASHINFER, - ) - ] - - MODELS_FP4 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", - {"kv_cache_dtype": "fp8", "max_model_len": 1024}, - _Backend.FLASHINFER, - ) - ] - - MODELS += [ - ( - "meta-llama/Llama-3.1-8B-Instruct", - {"max_model_len": 1024}, - _Backend.FLASHINFER, - ) - ] + MODELS_FP4 = [ + ModelBackendTestCase( + model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), + backend=_Backend.FLASHINFER, + attention_fusions=48, + allreduce_fusions=96, + ), + ] -elif current_platform.is_rocm(): - MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)] + # TP only + MODELS = [ + ModelBackendTestCase( + model_name="meta-llama/Llama-3.1-8B-Instruct", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=0, + allreduce_fusions=64, + ), + ] -INDUCTOR_GRAPH_PARTITION = ( - [True, False] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] -) +elif current_platform.is_rocm(): + MODELS_FP8 = [ + ModelBackendTestCase( + model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=32, + ), + ModelBackendTestCase( + model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.ROCM_ATTN, + attention_fusions=32, + ), + ModelBackendTestCase( + model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.ROCM_AITER_FA, # TODO ROCM_AITER_UNIFIED_ATTN + attention_fusions=32, + ), + ] # TODO(luka) test both in nightly CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] @pytest.mark.parametrize( - "model_name, model_kwargs, backend, custom_ops", + "model_name, model_kwargs, backend, " + "attention_fusions, allreduce_fusions, custom_ops", # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) # quant_fp4 only has the custom impl + list(flat_product(MODELS_FP4, [""])), ) -@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) def test_attn_quant( model_name: str, model_kwargs: dict[str, Any], backend: _Backend, + attention_fusions: int, + allreduce_fusions: int, custom_ops: str, inductor_graph_partition: bool, caplog_mp_spawn, monkeypatch, ): + if backend == _Backend.FLASHINFER and ( + not current_platform.is_device_capability((10, 0)) or not has_flashinfer() + ): + pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + custom_ops_list = custom_ops.split(",") if custom_ops else [] if inductor_graph_partition: @@ -120,7 +160,9 @@ def test_attn_quant( with caplog_mp_spawn(logging.DEBUG) as log_holder: run_model(compilation_config, model_name, **model_kwargs) - assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text + assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, ( + log_holder.text + ) # TODO(luka) test both in nightly @@ -135,14 +177,15 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "model_name, model_kwargs, backend, custom_ops", + "model_name, model_kwargs, backend, " + "attention_fusions, allreduce_fusions, custom_ops", # Toggle RMSNorm and QuantFP8 for FP8 models list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"])) # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO # Toggle RMSNorm for FP4 models and unquant models + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), ) -@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) @pytest.mark.skipif( not current_platform.is_cuda() or not has_flashinfer() @@ -150,14 +193,19 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: reason="allreduce+rmsnorm fusion requires flashinfer", ) def test_tp2_attn_quant_allreduce_rmsnorm( - model_name, - model_kwargs, - backend, + model_name: str, + model_kwargs: dict, + backend: _Backend, + attention_fusions: int, + allreduce_fusions: int, custom_ops: str, inductor_graph_partition: bool, caplog_mp_spawn, monkeypatch, ): + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + custom_ops_list = custom_ops.split(",") if custom_ops else [] if inductor_graph_partition: @@ -198,10 +246,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm( compilation_config, model_name, tensor_parallel_size=2, **model_kwargs ) - assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text + assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, ( + log_holder.text + ) matches = re.findall( - r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text + rf"\[collective_fusion.py:\d+] Replaced {allreduce_fusions} patterns", + log_holder.text, ) assert len(matches) == 2, log_holder.text From 32989d804e03ec0a3e9c97ee84b684e4683d0a67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 10 Oct 2025 13:49:09 -0400 Subject: [PATCH 039/137] add pattern for final allreduce in model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/collective_fusion.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index d5a3fcde03b6..d6ad0cd38f94 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -775,6 +775,18 @@ def replacement( pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) + # Same pattern, but only return the output and not residual + # (helpful for end of graph where residual is not used again) + first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0] + + pm.register_replacement( + first_return_only(pattern), + first_return_only(replacement), + self.get_inputs(), + pm.fwd_only, + pm_pass, + ) + class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): """ From 46ee6267b7b8f78fe038bbcda5650a26a2031133 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 10 Oct 2025 13:51:13 -0400 Subject: [PATCH 040/137] add more comprehensive testing for quantfp8 (-rmsnorm+-quant still failing) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_all_reduce.py | 89 +++++++++++++++++-------- 1 file changed, 62 insertions(+), 27 deletions(-) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 657ebc4a28a6..fa0293497aba 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -26,8 +26,8 @@ ) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, GroupShape, - QuantFP8, ) from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -43,9 +43,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): self.eps = eps self.norm = RMSNorm(hidden_size, eps) - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) + def forward(self, x): + z = torch.relu(x) + all_reduce = tensor_model_parallel_all_reduce(z) norm = self.norm(all_reduce) return norm @@ -63,9 +63,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): self.eps = eps self.norm = RMSNorm(hidden_size, eps) - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) + def forward(self, hidden_states): + z = residual = torch.relu(hidden_states) + all_reduce = tensor_model_parallel_all_reduce(z) norm, res = self.norm(all_reduce, residual) return norm, res @@ -77,21 +77,53 @@ def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] -class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): +class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps - self.norm = RMSNorm(hidden_size, eps) - self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) - self.scale = torch.rand(1, dtype=torch.float32) + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.w = [ + torch.rand(hidden_size, hidden_size) + .to(dtype=current_platform.fp8_dtype()) + .t() + for _ in range(3) + ] - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) - norm_output, residual_output = self.norm(all_reduce, residual) - quant_out, _ = self.quant_fp8(norm_output, self.scale) - return quant_out, residual_output + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, + act_quant_group_shape=GroupShape.PER_TENSOR, + ) + + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + + def forward(self, hidden_states): + # avoid having graph input be an arg to a pattern directly + z = torch.relu(hidden_states) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) + + z2 = self.fp8_linear.apply( + y, self.w[0], self.wscale[0], input_scale=self.scale[0] + ) + + x2 = tensor_model_parallel_all_reduce(z2) + y2, resid = self.norm[1](x2, resid) + + z3 = self.fp8_linear.apply( + y2, self.w[1], self.wscale[1], input_scale=self.scale[1] + ) + + x3 = tensor_model_parallel_all_reduce(z3) + y3, resid = self.norm[2](x3, resid) # use resid here + + z4 = self.fp8_linear.apply( + y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + ) + x4 = tensor_model_parallel_all_reduce(z4) + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -100,7 +132,7 @@ def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, torch.ops._C.static_scaled_fp8_quant.default - if self.quant_fp8.enabled() + if self.fp8_linear.quant_fp8.enabled() else torch.ops.aten.reciprocal.default, ] @@ -120,11 +152,10 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): rounded_n = round_up(scale_n, 4) self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32) - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) + def forward(self, hidden_states): + z = residual = torch.relu(hidden_states) + all_reduce = tensor_model_parallel_all_reduce(z) norm_output, residual_output = self.norm(all_reduce, residual) - norm_output = norm_output.reshape(-1, norm_output.shape[-1]) torch.ops._C.scaled_fp4_quant( self.output, norm_output, self.output_scale, self.scale ) @@ -146,8 +177,8 @@ def ops_in_model_before(self): [ (TestAllReduceRMSNormModel, False), (TestAllReduceFusedAddRMSNormModel, False), - (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, True), - (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, False), + (TestAllReduceRMSNormStaticQuantFP8Model, True), + (TestAllReduceRMSNormStaticQuantFP8Model, False), (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), ], ) @@ -269,12 +300,16 @@ def all_reduce_fusion_pass_on_test_model( model = test_model_cls(hidden_size, token_num) hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) - residual = torch.randn((token_num, hidden_size), requires_grad=False) compiled_model = torch.compile(model, backend=backend) - compiled_model(hidden_states, residual) + compiled_model(hidden_states) - assert all_reduce_fusion_pass.matched_count == 1 + # TODO cleanup + expected = 4 if test_model_cls is TestAllReduceRMSNormStaticQuantFP8Model else 1 + + assert all_reduce_fusion_pass.matched_count == expected, ( + f"{all_reduce_fusion_pass.matched_count=}, {expected=}" + ) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) del all_reduce_fusion_pass From a1c7fdb32ad96648ebfdf94c79069d448ffbb2a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 10 Oct 2025 16:13:42 -0400 Subject: [PATCH 041/137] add more comprehensive testing for allreduce-rmsnorm, fix fp4 (-rmsnorm still failing) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_all_reduce.py | 97 +++++++++++++++---------- vllm/compilation/collective_fusion.py | 16 +--- 2 files changed, 59 insertions(+), 54 deletions(-) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index fa0293497aba..0c9d584ddf46 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -6,6 +6,7 @@ import torch import vllm.envs as envs +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.compilation.collective_fusion import AllReduceFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.noop_elimination import NoOpEliminationPass @@ -41,34 +42,30 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps - self.norm = RMSNorm(hidden_size, eps) + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] def forward(self, x): + # avoid having graph input be an arg to a pattern directly z = torch.relu(x) - all_reduce = tensor_model_parallel_all_reduce(z) - norm = self.norm(all_reduce) - return norm + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) - def ops_in_model_before(self): - return [torch.ops.vllm.all_reduce.default] + z2 = torch.mm(y, self.w[0]) + x2 = tensor_model_parallel_all_reduce(z2) - def ops_in_model_after(self): - return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] + y2, resid = self.norm[1](x2, resid) + z3 = torch.mm(y2, self.w[1]) + x3 = tensor_model_parallel_all_reduce(z3) -class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): - super().__init__() - self.hidden_size = hidden_size - self.eps = eps - self.norm = RMSNorm(hidden_size, eps) + y3, resid = self.norm[2](x3, resid) - def forward(self, hidden_states): - z = residual = torch.relu(hidden_states) - all_reduce = tensor_model_parallel_all_reduce(z) - norm, res = self.norm(all_reduce, residual) + z4 = torch.mm(y3, self.w[2]) + x4 = tensor_model_parallel_all_reduce(z4) - return norm, res + y4, resid = self.norm[3](x4, resid) + return y4 def ops_in_model_before(self): return [torch.ops.vllm.all_reduce.default] @@ -142,24 +139,48 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps - self.norm = RMSNorm(hidden_size, eps) - self.scale = torch.rand(1, dtype=torch.float32) - self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + + self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] + self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)] - round_up = lambda x, y: (x + y - 1) // y * y - rounded_m = round_up(token_num, 128) - scale_n = hidden_size // 16 - rounded_n = round_up(scale_n, 4) - self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32) + wq_gen, wscale_gen = zip( + *(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale)) + ) + self.wq, self.wscale = list(wq_gen), list(wscale_gen) + print(f"{self.wq=}, {self.wscale=}") def forward(self, hidden_states): - z = residual = torch.relu(hidden_states) - all_reduce = tensor_model_parallel_all_reduce(z) - norm_output, residual_output = self.norm(all_reduce, residual) - torch.ops._C.scaled_fp4_quant( - self.output, norm_output, self.output_scale, self.scale + # avoid having graph input be an arg to a pattern directly + z = torch.relu(hidden_states) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) + + yq, y_scale = scaled_fp4_quant(y, self.agscale[0]) + z2 = cutlass_scaled_fp4_mm( + yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype + ) + + x2 = tensor_model_parallel_all_reduce(z2) + y2, resid = self.norm[1](x2, resid) + + yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1]) + z3 = cutlass_scaled_fp4_mm( + yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype + ) + + x3 = tensor_model_parallel_all_reduce(z3) + y3, resid = self.norm[2](x3, resid) # use resid here + + yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2]) + z4 = cutlass_scaled_fp4_mm( + yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype ) - return self.output, residual_output, self.output_scale + x4 = tensor_model_parallel_all_reduce(z4) + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -176,7 +197,6 @@ def ops_in_model_before(self): "test_model, enable_quant_fp8", [ (TestAllReduceRMSNormModel, False), - (TestAllReduceFusedAddRMSNormModel, False), (TestAllReduceRMSNormStaticQuantFP8Model, True), (TestAllReduceRMSNormStaticQuantFP8Model, False), (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), @@ -184,7 +204,7 @@ def ops_in_model_before(self): ) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) -@pytest.mark.parametrize("hidden_size", [16]) +@pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("enable_rms_norm", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @@ -304,11 +324,8 @@ def all_reduce_fusion_pass_on_test_model( compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states) - # TODO cleanup - expected = 4 if test_model_cls is TestAllReduceRMSNormStaticQuantFP8Model else 1 - - assert all_reduce_fusion_pass.matched_count == expected, ( - f"{all_reduce_fusion_pass.matched_count=}, {expected=}" + assert all_reduce_fusion_pass.matched_count == 4, ( + f"{all_reduce_fusion_pass.matched_count=}" ) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index d6ad0cd38f94..cc4f2152e1c5 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -960,10 +960,6 @@ def __init__( def register(self, pm_pass: PatternMatcherPass): def get_inputs(): input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype) - - rmsnorm_result = torch.empty( - [1, 16, 16], device=self.device, dtype=self.dtype - ) quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) input_global_scale = torch.empty( [1, 1], device=self.device, dtype=torch.float32 @@ -971,18 +967,10 @@ def get_inputs(): weight = torch.empty([16], device=self.device, dtype=self.dtype) output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) - return [ - input, - rmsnorm_result, - quant_result, - weight, - input_global_scale, - output_scale, - ] + return [input, quant_result, weight, input_global_scale, output_scale] def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, quant_result: torch.Tensor, weight: torch.Tensor, input_global_scale: torch.Tensor, @@ -1003,13 +991,13 @@ def pattern( def replacement( input: torch.Tensor, - result_rms: torch.Tensor, quant_result: torch.Tensor, weight: torch.Tensor, input_global_scale: torch.Tensor, output_scale: torch.Tensor, ): residual = torch.zeros_like(input) + result_rms = torch.empty_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, From c3264d849f1ca0d8736e39c0c25f6420930105a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 10 Oct 2025 18:36:15 -0400 Subject: [PATCH 042/137] Fix partial match rmsnorm+quant, fix allreduce+rmsnorm match MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 40 +++++++++++++++++++++++-- vllm/compilation/fusion.py | 15 +++++----- vllm/compilation/fx_utils.py | 16 ++++++++-- vllm/model_executor/layers/layernorm.py | 5 +++- 4 files changed, 62 insertions(+), 14 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 13cffbe087c6..4ab450827609 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,7 +5,9 @@ import torch import vllm.plugins -from vllm.compilation.fusion import RMSNormQuantFusionPass +from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass +from vllm.compilation.fx_utils import find_op_nodes +from vllm.compilation.matcher_utils import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import ( @@ -33,6 +35,9 @@ FP8_DTYPE = current_platform.fp8_dtype() +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + class TestModel(torch.nn.Module): def __init__( @@ -50,7 +55,7 @@ def __init__( self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN quant_scale = ScaleDesc(torch.float32, static, group_shape) - self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) + self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] else: @@ -93,6 +98,22 @@ def forward(self, x): y4, resid = self.norm[3](x4, resid) # use resid here return y4 + def ops_in_model_after(self): + return [ + FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)], + FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)], + ] + + def ops_in_model_before(self): + return ( + [QUANT_OPS[self.quant_key]] + if self.enable_quant_fp8 + else [torch.ops.aten.reciprocal] + ) + + def ops_in_model_before_partial(self): + return [RMS_OP, RMS_ADD_OP] if self.enable_rms_norm else [torch.ops.aten.rsqrt] + @pytest.mark.parametrize("dtype", [torch.float16]) # , torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @@ -164,3 +185,18 @@ def test_fusion_rmsnorm_quant( torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) assert fusion_pass.matched_count == 3 + backend.check_before_ops(model.ops_in_model_before()) + backend.check_before_ops( + model.ops_in_model_before_partial(), fully_replaced=False + ) + backend.check_after_ops(model.ops_in_model_after()) + + # If RMSNorm custom op is disabled (native/torch impl used), + # there's a risk that the fused add doesn't get included in the + # replacement and only the rms part gets fused with quant. + # Hence, we check only 2 add nodes are left (final fused rmsnorm add). + if not enable_rms_norm: + n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) + # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) + assert n_add_nodes(backend.graph_pre_pass) == 7 + assert n_add_nodes(backend.graph_post_pass) == 2 diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 883743b635a8..9ace7a8cf050 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -94,9 +94,6 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon self.quant_dtype = key.quant.dtype - assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}" - self.QUANT_OP = QUANT_OPS[key.quant] - assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] @@ -334,23 +331,25 @@ def __init__(self, config: VllmConfig): pass_name="rmsnorm_quant_fusion_pass" ) + # Make sure fused add patterns are before simple rms norm, + # as the latter is a subset of the former in torch ops for epsilon in [1e-5, 1e-6]: - # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) - # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( self.patterns ) - # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + # Fuse rms_norm + static fp8 quant + RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( self.patterns ) + # Fuse rms_norm + dynamic per-token fp8 quant + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index 114b53c74c48..3209c49eba26 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -3,11 +3,11 @@ import operator from collections.abc import Iterable, Iterator -from typing import Optional +from typing import Optional, Union from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized -from torch._ops import OpOverload +from torch._ops import OpOverload, OpOverloadPacket def is_func(node: fx.Node, target) -> bool: @@ -67,7 +67,17 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node: # An auto-functionalization-aware utility for finding nodes with a specific op -def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]: +# Also handles op overload packets and finds all overloads +def find_op_nodes( + op: Union[OpOverload, OpOverloadPacket], graph: fx.Graph +) -> Iterator[fx.Node]: + if isinstance(op, OpOverloadPacket): + for overload in op.overloads(): + overload_op = getattr(op, overload) + yield from find_op_nodes(overload_op, graph) + return + + assert isinstance(op, OpOverload) if not op._schema.is_mutable: yield from graph.find_nodes(op="call_function", target=op) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 7e15efab379b..b70ea33f2cd2 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -195,7 +195,10 @@ def forward_static( orig_dtype = x.dtype x = x.to(torch.float32) if residual is not None: - x = x + residual.to(torch.float32) + # residual promoted f16->f32 automatically, + # otherwise Inductor eliminates the casts to and from f16, + # increasing memory usage (and complicating pattern matching) + x = x + residual residual = x.to(orig_dtype) if x.shape[-1] != hidden_size: From 095277ca89b85a0ae7952218c97f0ababa34f30b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 10 Oct 2025 19:03:18 -0400 Subject: [PATCH 043/137] Simplify matcher utils by using RMSNorm.forward_static MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/matcher_utils.py | 38 ++++--------------------- vllm/model_executor/layers/layernorm.py | 3 +- 2 files changed, 8 insertions(+), 33 deletions(-) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index fe558b7acac2..cc5e7ba8310d 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -65,8 +65,6 @@ def inputs(self) -> list[torch.Tensor]: class MatcherRMSNorm(MatcherCustomOp): def __init__(self, epsilon: float, enabled: Optional[bool] = None): if enabled is None: - # TODO either pass config to enabled or set it globally - # (global during pass init seems reasonable) enabled = RMSNorm.enabled() super().__init__(enabled) @@ -83,7 +81,6 @@ def forward_custom( self, input: torch.Tensor, weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, ) -> torch.Tensor: result = torch.empty_like(input) _, result = auto_functionalized( @@ -100,28 +97,15 @@ def forward_native( self, input: torch.Tensor, weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, ) -> torch.Tensor: - x = input.to(torch.float32) - if residual is not None: - x = x + residual - residual = x.to(self.model_dtype) - - variance = x.pow(2).mean(dim=-1, keepdim=True) - - x = x * torch.rsqrt(variance + self.epsilon) - x = x.to(self.model_dtype) - if weight is not None: - x = x * weight - - return x if residual is None else (x, residual) + return RMSNorm.forward_static( + input, self.epsilon, input.size(-1), self.model_dtype, weight + ) class MatcherFusedAddRMSNorm(MatcherCustomOp): def __init__(self, epsilon: float, enabled: Optional[bool] = None): if enabled is None: - # TODO either pass config to enabled or set it globally - # (global during pass init seems reasonable) enabled = RMSNorm.enabled() super().__init__(enabled) @@ -157,19 +141,9 @@ def forward_native( weight: torch.Tensor, residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - x = input.to(torch.float32) - if residual is not None: - x = x + residual - residual = x.to(self.model_dtype) - - variance = x.pow(2).mean(dim=-1, keepdim=True) - - x = x * torch.rsqrt(variance + self.epsilon) - x = x.to(self.model_dtype) - if weight is not None: - x = x * weight - - return x if residual is None else (x, residual) + return RMSNorm.forward_static( + input, self.epsilon, input.size(-1), self.model_dtype, weight, residual + ) class MatcherQuant: diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index b70ea33f2cd2..5b9d24c19a3c 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -187,12 +187,12 @@ def forward_static( x: torch.Tensor, variance_epsilon: float, hidden_size: int, + orig_dtype: torch.dtype, weight: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None, variance_size_override: Optional[int] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" - orig_dtype = x.dtype x = x.to(torch.float32) if residual is not None: # residual promoted f16->f32 automatically, @@ -239,6 +239,7 @@ def forward_native( x, self.variance_epsilon, self.hidden_size, + x.dtype, self.weight.data if self.has_weight else None, residual, self.variance_size_override, From 52f78ce6760f9e8754cf6cde299d0448b93411a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sat, 11 Oct 2025 08:38:42 -0400 Subject: [PATCH 044/137] Add allreduce test to 2-gpu test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 85616de5b197..f02fa0c27373 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -812,6 +812,10 @@ steps: - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -833,7 +837,6 @@ steps: - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py # Fusion - - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py @@ -1090,7 +1093,7 @@ steps: - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 ##### H200 test ##### -- label: Distrubted Tests (H200) # optional +- label: Distributed Tests (H200) # optional gpu: h200 optional: true working_dir: "/vllm-workspace/" @@ -1110,6 +1113,7 @@ steps: commands: - pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm ##### RL Integration Tests ##### From 1b1a63eb2e3086a94fd3a350531f8707dcb7be3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sat, 11 Oct 2025 14:33:46 -0400 Subject: [PATCH 045/137] Fix e2e allreduce fusion test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusions_e2e.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 6e4893cd0f66..f80bdb06bc68 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -69,7 +69,7 @@ class ModelBackendTestCase(NamedTuple): model_kwargs=dict(max_model_len=1024), backend=_Backend.TRITON_ATTN, attention_fusions=0, - allreduce_fusions=64, + allreduce_fusions=65, ), ] @@ -166,8 +166,7 @@ def test_attn_quant( # TODO(luka) test both in nightly -# TODO(luka) change to - -CUSTOM_OPS_RMS_NORM = ["+rms_norm"] # , "+rms_norm"] +CUSTOM_OPS_RMS_NORM = ["-rms_norm"] # , "+rms_norm"] def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @@ -180,8 +179,11 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: "model_name, model_kwargs, backend, " "attention_fusions, allreduce_fusions, custom_ops", # Toggle RMSNorm and QuantFP8 for FP8 models - list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"])) - # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO + list( + flat_product( + MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) + ) + ) # TODO # Toggle RMSNorm for FP4 models and unquant models + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), ) @@ -245,17 +247,26 @@ def test_tp2_attn_quant_allreduce_rmsnorm( run_model( compilation_config, model_name, tensor_parallel_size=2, **model_kwargs ) - - assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, ( - log_holder.text + matches = re.findall( + r"\[compilation/fusion_attn.py:\d+] " + r"Fused quant onto (\d+) attention nodes", + log_holder.text, ) + assert len(matches) == 2, log_holder.text + + assert int(matches[0]) == attention_fusions + assert int(matches[1]) == attention_fusions matches = re.findall( - rf"\[collective_fusion.py:\d+] Replaced {allreduce_fusions} patterns", + r"\[compilation/collective_fusion.py:\d+] " + r"Replaced (\d+) patterns", log_holder.text, ) assert len(matches) == 2, log_holder.text + assert int(matches[0]) == allreduce_fusions + assert int(matches[1]) == allreduce_fusions + def run_model( compile_config: Union[int, CompilationConfig], model: str, **model_kwargs From 0d6e550bfe3032e374257b1fb495a6e7758b2fab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sun, 12 Oct 2025 10:57:07 -0400 Subject: [PATCH 046/137] fix func test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/backend.py | 3 +- tests/compile/test_functionalization.py | 80 ++++++++++++++----------- 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/tests/compile/backend.py b/tests/compile/backend.py index a16ab9f15c9f..5d0e30ea5f39 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -54,7 +54,8 @@ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.custom_passes = list(passes) vllm_config = get_current_vllm_config() compile_config = vllm_config.compilation_config - self.inductor_config = compile_config.inductor_compile_config + # Deepcopy to allow multiple TestBackend instances to use the same VllmConfig + self.inductor_config = deepcopy(compile_config.inductor_compile_config) self.inductor_config["force_disable_caches"] = True self.inductor_config["post_grad_custom_post_pass"] = self.post_pass diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index ae17bc67b1fb..dd424d7f6ad0 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -11,7 +11,13 @@ from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import CompilationConfig, PassConfig, VllmConfig +from vllm.config import ( + CompilationConfig, + ModelConfig, + PassConfig, + VllmConfig, + set_current_vllm_config, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape @@ -217,42 +223,48 @@ def ops_not_in_model(self): def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool): torch.set_default_device("cuda") - vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig( - pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True) + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=torch.bfloat16), + compilation_config=CompilationConfig( + custom_ops=["all"], + pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True), + ), ) - noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = RMSNormQuantFusionPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) - act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) - - passes = ( - [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] - if do_fusion - else [noop_pass, cleanup_pass] - ) - func_pass = FixFunctionalizationPass(vllm_config) - backend_func = TestBackend(*passes, func_pass) - backend_no_func = TestBackend(*passes) + with set_current_vllm_config(vllm_config): + assert RMSNorm.enabled() + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) + + passes = ( + [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] + if do_fusion + else [noop_pass, cleanup_pass] + ) + func_pass = FixFunctionalizationPass(vllm_config) - model = model_class() - torch.compile(model, backend=backend_func)(*model.example_inputs()) - torch.compile(model, backend=backend_no_func)(*model.example_inputs()) + backend_func = TestBackend(*passes, func_pass) + backend_no_func = TestBackend(*passes) - # check if the functionalization pass is applied - for op in model.ops_in_model(do_fusion): - find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + model = model_class() + torch.compile(model, backend=backend_func)(*model.example_inputs()) + torch.compile(model, backend=backend_no_func)(*model.example_inputs()) - # make sure the ops were all de-functionalized - found = dict() - for node in backend_func.graph_post_pass.nodes: + # check if the functionalization pass is applied for op in model.ops_in_model(do_fusion): - if is_func(node, op): - found[op] = True - for op in model.ops_not_in_model(): - if is_func(node, op): - found[op] = True - assert all(found[op] for op in model.ops_in_model(do_fusion)) - assert all(not found.get(op) for op in model.ops_not_in_model()) + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in model.ops_in_model(do_fusion): + if is_func(node, op): + found[op] = True + for op in model.ops_not_in_model(): + if is_func(node, op): + found[op] = True + assert all(found[op] for op in model.ops_in_model(do_fusion)) + assert all(not found.get(op) for op in model.ops_not_in_model()) From 26892dfa100a962b4502bf2e89f7610cc81b912d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sun, 12 Oct 2025 11:03:35 -0400 Subject: [PATCH 047/137] fix pass manager test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_pass_manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index ac561d2e8f84..1c40c599f748 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -7,7 +7,7 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.pass_manager import PostGradPassManager -from vllm.config import VllmConfig +from vllm.config import ModelConfig, VllmConfig # dummy custom pass that doesn't inherit @@ -42,7 +42,8 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None: ], ) def test_pass_manager_uuid(callable): - config = VllmConfig() + # Some passes need dtype to be set + config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16)) pass_manager = PostGradPassManager() pass_manager.configure(config) From 3547b877ad82ec5f1de52ca4a27aa186a119a50d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sun, 12 Oct 2025 11:11:14 -0400 Subject: [PATCH 048/137] fix sequence parallelism test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_sequence_parallelism.py | 92 ++++++++++++---------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index afb31cb95be0..bca3932ffaf0 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -18,6 +18,7 @@ ModelConfig, PassConfig, VllmConfig, + set_current_vllm_config, ) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -42,9 +43,7 @@ class TestModel(torch.nn.Module): - def __init__( - self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None - ): + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size @@ -266,68 +265,77 @@ def sequence_parallelism_pass_on_test_model( initialize_model_parallel(tensor_model_parallel_size=world_size) # configure vllm config for SequenceParallelismPass - vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig( + compilation_config = CompilationConfig( pass_config=PassConfig( enable_sequence_parallelism=True, enable_fusion=enable_fusion, enable_noop=True, ) ) # NoOp needed for fusion - vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig( + model_config = ModelConfig( model=model_name, trust_remote_code=True, dtype=dtype, seed=42 ) - noop_pass = NoOpEliminationPass(vllm_config) - sequence_parallelism_pass = SequenceParallelismPass(vllm_config) - func_pass = FixFunctionalizationPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) + vllm_config = VllmConfig( + model_config=model_config, + device_config=device_config, + compilation_config=compilation_config, + ) - passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass] + with set_current_vllm_config(vllm_config): + noop_pass = NoOpEliminationPass(vllm_config) + sequence_parallelism_pass = SequenceParallelismPass(vllm_config) + func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - if enable_fusion: - fusion_pass = RMSNormQuantFusionPass(vllm_config) - passes_for_backend.append(fusion_pass) + passes_for_backend: list[VllmInductorPass] = [ + noop_pass, + sequence_parallelism_pass, + ] - passes_for_backend.append(cleanup_pass) + if enable_fusion: + fusion_pass = RMSNormQuantFusionPass(vllm_config) + passes_for_backend.append(fusion_pass) - backend_no_func = TestBackend(*passes_for_backend) - backend_func = TestBackend(*passes_for_backend, func_pass) + passes_for_backend.append(cleanup_pass) - model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config) + backend_no_func = TestBackend(*passes_for_backend) + backend_func = TestBackend(*passes_for_backend, func_pass) - hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + model = test_model_cls(hidden_size, hidden_size * 2) - compiled_model_no_func = torch.compile(model, backend=backend_no_func) - compiled_model_no_func(hidden_states, residual) - compiled_model_func = torch.compile(model, backend=backend_func) - compiled_model_func(hidden_states, residual) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - assert sequence_parallelism_pass.matched_count == 1 + compiled_model_no_func = torch.compile(model, backend=backend_no_func) + compiled_model_no_func(hidden_states, residual) + compiled_model_func = torch.compile(model, backend=backend_func) + compiled_model_func(hidden_states, residual) - # In pre-nodes, all reduce should be there, - # reduce scatter and all gather should not - backend_no_func.check_before_ops(model.ops_in_model_before()) + assert sequence_parallelism_pass.matched_count == 1 - # In post-nodes, reduce scatter and all gather should be there, - # all reduce should not - backend_no_func.check_after_ops(model.ops_in_model_after()) + # In pre-nodes, all reduce should be there, + # reduce scatter and all gather should not + backend_no_func.check_before_ops(model.ops_in_model_before()) - # check if the functionalization pass is applied - for op in model.ops_in_model(): - find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + # In post-nodes, reduce scatter and all gather should be there, + # all reduce should not + backend_no_func.check_after_ops(model.ops_in_model_after()) - # make sure the ops were all de-functionalized - found = dict() - for node in backend_func.graph_post_pass.nodes: + # check if the functionalization pass is applied for op in model.ops_in_model(): - if is_func(node, op): - found[op] = True - assert all(found[op] for op in model.ops_in_model()) + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in model.ops_in_model(): + if is_func(node, op): + found[op] = True + assert all(found[op] for op in model.ops_in_model()) From af1ffa77d5606a30693a8c98d2333a50443ac5eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 01:54:18 -0400 Subject: [PATCH 049/137] PR review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/collective_fusion.py | 26 +++++--------------------- vllm/compilation/fusion.py | 4 ++-- vllm/compilation/fusion_attn.py | 4 ++-- vllm/compilation/matcher_utils.py | 21 ++++++++------------- 4 files changed, 17 insertions(+), 38 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index cc4f2152e1c5..d0e99497a372 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -25,7 +25,7 @@ from vllm.utils import direct_register_custom_op from .inductor_pass import enable_fake_mode -from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuant, MatcherRMSNorm +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass FP8_DTYPE = current_platform.fp8_dtype() @@ -46,11 +46,8 @@ logger = init_logger(__name__) -ALLREDUCE_OP = torch.ops.vllm.all_reduce.default -RMS_OP = torch.ops._C.rms_norm.default -RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default -STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default -STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default +if hasattr(torch.ops._C, "scaled_fp4_quant"): + STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default class BasePattern: @@ -650,19 +647,6 @@ def get_trtllm_fused_allreduce_kwargs(self): } -class BaseAllReduceRMSNormPattern(BasePattern): - def __init__( - self, - epsilon: float, - dtype: torch.dtype, - device: str, - allreduce_params: FlashInferFusedAllReduceParams, - ): - super().__init__(dtype, device) - self.epsilon = epsilon - self.allreduce_params = allreduce_params - - class AllReduceRMSNormPattern(BasePattern): """ This pattern replaces the allreduce + rms norm (without residual) @@ -808,7 +792,7 @@ def __init__( self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn self.rmsnorm_matcher = MatcherRMSNorm(epsilon) - self.quant_matcher = MatcherQuant(kFp8StaticTensorSym) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): @@ -877,7 +861,7 @@ def __init__( self.quant_dtype = torch.float8_e4m3fn self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) - self.quant_matcher = MatcherQuant(kFp8StaticTensorSym) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 9ace7a8cf050..d6057e869ae0 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -24,7 +24,7 @@ from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode -from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuant, MatcherRMSNorm +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -102,7 +102,7 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey): if not key.fused_add else MatcherFusedAddRMSNorm(epsilon) ) - self.quant_matcher = MatcherQuant(key.quant) + self.quant_matcher = MatcherQuantFP8(key.quant) class RMSNormStaticQuantPattern(RMSNormQuantPattern): diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 761acb35834b..2f3b0963d365 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -24,7 +24,7 @@ from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .fx_utils import is_func from .inductor_pass import enable_fake_mode -from .matcher_utils import MatcherQuant +from .matcher_utils import MatcherQuantFP8 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -129,7 +129,7 @@ def __init__( dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric ) super().__init__(layer, quant_key, dtype) - self.quant_matcher = MatcherQuant(quant_key) + self.quant_matcher = MatcherQuantFP8(quant_key) def _register(self, pm_pass: PatternMatcherPass): def pattern( diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index cc5e7ba8310d..4b1c714fe4a4 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -146,22 +146,22 @@ def forward_native( ) -class MatcherQuant: +class MatcherQuantFP8(MatcherCustomOp): def __init__(self, quant_key: QuantKey, enabled: Optional[bool] = None): + if enabled is None: + enabled = QuantFP8.enabled() + + super().__init__(enabled) self.quant_key = quant_key assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" self.QUANT_OP = QUANT_OPS[quant_key] + assert quant_key.dtype == current_platform.fp8_dtype(), ( + "Only QuantFP8 supported by" + ) assert quant_key.scale2 is None self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape) - if enabled is None: - # TODO either pass config to enabled or set it globally - # (global during pass init seems reasonable) - enabled = self.quant_fp8.enabled() - - self.forward = self.forward_custom if enabled else self.forward_native - def forward_custom( self, input: torch.Tensor, @@ -204,8 +204,3 @@ def make_scale(self, input: torch.Tensor): ) return torch.empty(scale_shape, device=input.device, dtype=torch.float32) - - def __call__( - self, input: torch.Tensor, scale: Optional[torch.Tensor] = None - ) -> tuple[torch.Tensor, torch.Tensor]: - return self.forward(input, scale) From b5f89e5d0291d3c9e6a2152f75b186eaa10862dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 02:29:06 -0400 Subject: [PATCH 050/137] Cleanup test_full_graph.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 66 +++++++++++++++++++------------- 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 9a955f4c9d81..fb511dd8f7ca 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import tempfile +from pathlib import Path from typing import Any import pytest @@ -21,27 +22,21 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): ("facebook/opt-125m", {}), ( "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", - { - "dtype": torch.float16, - }, + {"dtype": torch.float16}, ), ("meta-llama/Llama-3.2-1B-Instruct", {}), ] if all: - if not current_platform.has_device_capability((10, 0)): - # int8 removed on Blackwell - TEST_MODELS.extend( - [ - ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), - ( - "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", - { - "dtype": torch.float16, - }, - ), - ] - ) + TEST_MODELS.extend( + [ + ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), + ( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + {"dtype": torch.float16}, + ), + ] + ) # TODO: figure out why this fails. if False and is_quant_method_supported("gguf"): # noqa: SIM223 @@ -95,6 +90,14 @@ def test_full_graph( model_kwargs: dict[str, Any], compilation_mode: int, ): + if ( + "w8a8" in model + or "w8w8" in model + and current_platform.has_device_capability((10, 0)) + ): + # int8 removed on Blackwell: + pytest.skip("int8 support removed on Blackwell") + with monkeypatch.context(): print(f"MODEL={model}") @@ -103,14 +106,14 @@ def test_full_graph( # TODO(luka) add other supported compilation config scenarios here @pytest.mark.parametrize( - "compilation_config, model_info", + "compilation_config, model, model_kwargs", [ # additional compile sizes, only some of the models ( CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]), - model, + *model_info, ) - for model in models_list(all=False) + for model_info in models_list(all=False) ] + [ # RMSNorm + quant fusion, only 8-bit quant models @@ -120,18 +123,19 @@ def test_full_graph( custom_ops=["+rms_norm"], pass_config=PassConfig(enable_fusion=True, enable_noop=True), ), - model, + *model_info, ) - for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) + for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) ] + [ # Test depyf integration works ( CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - debug_dump_path=tempfile.gettempdir(), + debug_dump_path=Path(tempfile.gettempdir()), ), - ("facebook/opt-125m", {}), + "facebook/opt-125m", + {}, ), ] + [ @@ -145,9 +149,9 @@ def test_full_graph( cudagraph_mode=CUDAGraphMode.PIECEWISE, compile_sizes=[1, 2], ), - model, + *model_info, ) - for model in models_list(all=False) + for model_info in models_list(all=False) if is_torch_equal_or_newer("2.9.0.dev") ], ) @@ -155,14 +159,22 @@ def test_full_graph( @create_new_process_for_each_test() def test_custom_compile_config( compilation_config: CompilationConfig, - model_info: tuple[str, dict[str, Any]], + model: str, + model_kwargs: dict[str, Any], ): + if ( + "w8a8" in model + or "w8w8" in model + and current_platform.has_device_capability((10, 0)) + ): + # int8 removed on Blackwell: + pytest.skip("int8 support removed on Blackwell") + if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer( "2.9.0.dev" ): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") - model, model_kwargs = model_info print(f"MODEL={model}") run_model(compilation_config, model, **model_kwargs) From f6429e416de6d5a0623a019f6afda7a9a5b2317a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 02:40:43 -0400 Subject: [PATCH 051/137] Cleanup test_fusion_attn.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 54 +++++++++++++++---------------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 375796952339..b6d8fc9e28dc 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -34,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer +from vllm.utils.flashinfer import has_flashinfer from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() @@ -238,52 +239,41 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ) -MODELS_FP8 = [] -MODELS_FP4 = [] -HEADS = [] -SPLIT_ATTENTION = [] +MODELS_FP8: list[tuple[str, type]] = [] +MODELS_FP4: list[tuple[str, type]] = [] +HEADS: list[tuple[int, int]] = [] +SPLIT_ATTENTION: list[bool] = [] BACKENDS_FP8: list[_Backend] = [] BACKENDS_FP4: list[_Backend] = [] if current_platform.is_cuda(): + HEADS = [(64, 8), (40, 8)] MODELS_FP8 = [ ( "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", TestAttentionFp8StaticQuantPatternModel, ) ] - HEADS = [(64, 8), (40, 8)] - SPLIT_ATTENTION = [False] - BACKENDS_FP8 = [_Backend.TRITON_ATTN] - - if current_platform.is_device_capability((10, 0)): - BACKENDS_FP8 += [_Backend.FLASHINFER] - BACKENDS_FP4 += [_Backend.FLASHINFER] - MODELS_FP4 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", - TestAttentionNvfp4QuantPatternModel, - ) - ] + MODELS_FP4 = [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel, + ) + ] + BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER] + BACKENDS_FP4 = [_Backend.FLASHINFER] elif current_platform.is_rocm(): + HEADS = [(32, 8), (40, 8)] MODELS_FP8 = [ ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ] - HEADS = [(32, 8), (40, 8)] - SPLIT_ATTENTION = [False, True] BACKENDS = [ - _Backend.TRITON_ATTN, _Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, + _Backend.TRITON_ATTN, ] -# TODO(boyuan/luka): test inductor graph partition on rocm -if is_torch_equal_or_newer("2.9.0.dev") and current_platform.is_cuda(): - USE_INDUCTOR_GRAPH_PARTITION = [False, True] -else: - USE_INDUCTOR_GRAPH_PARTITION = [False] - @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @pytest.mark.parametrize("head_size", [128]) @@ -298,7 +288,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): # quant_fp4 only has the custom impl + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])), ) -@pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION) +@pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" ) @@ -318,6 +308,14 @@ def test_attention_quant_pattern( caplog_vllm, ): """Test AttentionStaticQuantPattern fusion pass""" + if backend == _Backend.FLASHINFER and ( + not current_platform.is_device_capability((10, 0)) or not has_flashinfer() + ): + pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") + + # TODO(boyuan/luka): test inductor graph partition on rocm + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") custom_ops_list = custom_ops.split(",") if custom_ops else [] @@ -435,7 +433,7 @@ def test_attention_quant_pattern( ) # access the underlying `AttnFusionPass` on the `LazyInitPass` - assert attn_pass.pass_.matched_count == 1 + assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) # Check attention ops in the graph before and after fusion attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) From 8a363d397227d55865e2e66159910aa51a3cd47d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 02:43:03 -0400 Subject: [PATCH 052/137] Slight improvement for E2E fusion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusions_e2e.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index b650b48c7d37..f55f3e1d2947 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -160,9 +160,13 @@ def test_attn_quant( with caplog_mp_spawn(logging.DEBUG) as log_holder: run_model(compilation_config, model_name, **model_kwargs) - assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, ( - log_holder.text + matches = re.findall( + r"\[compilation/fusion_attn.py:\d+] " + r"Fused quant onto (\d+) attention nodes", + log_holder.text, ) + assert len(matches) == 1, log_holder.text + assert int(matches[0]) == attention_fusions # TODO(luka) test both in nightly From 12a7c6d5d2f38b874e6933c6a049c873d5c4b441 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 03:00:52 -0400 Subject: [PATCH 053/137] Tests & docs for flat_product MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/utils.py | 10 +++++++--- tests/utils_/test_utils.py | 24 +++++++++++++++++++++++- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 54c51ed284fa..3042bacd4bb6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1265,7 +1265,11 @@ def check_answers( def flat_product(*iterables: Iterable[Any]): - """Flatten lists of tuples into cartesian product.""" + """ + Flatten lists of tuples of the cartesian product. + Useful when we want to avoid nested tuples to allow + test params to be unpacked directly from the decorator. + """ for element in itertools.product(*iterables): - normalized = (e if isinstance(e, tuple) else [e] for e in element) - yield list(itertools.chain(*normalized)) + normalized = (e if isinstance(e, tuple) else (e,) for e in element) + yield tuple(itertools.chain(*normalized)) diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index af5fc758f2c2..a14431681150 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -47,7 +47,7 @@ unique_filepath, ) -from ..utils import create_new_process_for_each_test, error_on_warning +from ..utils import create_new_process_for_each_test, error_on_warning, flat_product @pytest.mark.asyncio @@ -993,3 +993,25 @@ def test_unique_filepath(): paths.add(path) assert len(paths) == 10 assert len(list(Path(temp_dir).glob("*.txt"))) == 10 + + +def test_flat_product(): + # Check regular itertools.product behavior + result1 = list(flat_product([1, 2, 3], ["a", "b"])) + assert result1 == [ + (1, "a"), + (1, "b"), + (2, "a"), + (2, "b"), + (3, "a"), + (3, "b"), + ] + + # check that the tuples get flattened + result2 = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)])) + assert result2 == [ + (1, 2, "a", 5, 6), + (1, 2, "b", 5, 6), + (3, 4, "a", 5, 6), + (3, 4, "b", 5, 6), + ] From 8ffb4744f86e003e08f2191292a7f2bfe731d13e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 03:25:26 -0400 Subject: [PATCH 054/137] Remove/fix TODOs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 2 +- tests/compile/test_fusions_e2e.py | 4 ++-- vllm/compilation/fusion.py | 16 +++++++++------- vllm/compilation/matcher_utils.py | 10 ++++++---- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index b6d8fc9e28dc..32b207ed0109 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -101,7 +101,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: num_blocks = batch_size * max_blocks backend = self.attn.backend - # TODO use get_kv_cache_stride_order + # TODO(luka) use get_kv_cache_stride_order # Create dummy KV cache for the selected backend if backend == _Backend.ROCM_ATTN: # k/v as 1st dimention diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index f55f3e1d2947..533e0c5867d3 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -90,7 +90,7 @@ class ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.ROCM_AITER_FA, # TODO ROCM_AITER_UNIFIED_ATTN + backend=_Backend.ROCM_AITER_UNIFIED_ATTN, attention_fusions=32, ), ] @@ -187,7 +187,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: flat_product( MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) ) - ) # TODO + ) # Toggle RMSNorm for FP4 models and unquant models + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), ) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index d6057e869ae0..d724eca03e82 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -9,7 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -93,6 +93,8 @@ class RMSNormQuantPattern: def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon self.quant_dtype = key.quant.dtype + config = get_current_vllm_config() + self.model_dtype = config.model_config.dtype if config.model_config else None assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] @@ -124,7 +126,7 @@ def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. - input = input.to(dtype=torch.float16) # TODO model dtype + input = input.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( @@ -179,8 +181,8 @@ def replacement( ): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. - input = input.to(dtype=torch.float16) # TODO model dtype - residual = residual.to(dtype=torch.float16) + input = input.to(dtype=self.model_dtype) + residual = residual.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( @@ -235,7 +237,7 @@ def pattern(input: torch.Tensor, weight: torch.Tensor): def replacement(input: torch.Tensor, weight: torch.Tensor): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. - input = input.to(dtype=torch.float16) # TODO model dtype + input = input.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) @@ -289,8 +291,8 @@ def replacement( ): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. - input = input.to(dtype=torch.float16) # TODO model dtype - residual = residual.to(dtype=torch.float16) + input = input.to(dtype=self.model_dtype) + residual = residual.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 2fba5bd0cdbe..9b3854d9fb52 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -34,7 +34,9 @@ class MatcherCustomOp(ABC): def __init__(self, enabled: bool): - self.model_dtype = get_current_vllm_config().model_config.dtype + config = get_current_vllm_config() + self.model_dtype = config.model_config.dtype if config.model_config else None + self.device = config.device_config.device if config.device_config else None self.enabled = enabled self.forward = self.forward_custom if enabled else self.forward_native @@ -51,10 +53,10 @@ def __call__(self, *args, **kws): return self.forward(*args, **kws) def empty(self, *args, **kws): - return torch.empty(*args, dtype=self.model_dtype, device="cuda", **kws) + return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws) def empty_f32(self, *args, **kws): - return torch.empty(*args, dtype=torch.float32, device="cuda", **kws) + return torch.empty(*args, dtype=torch.float32, device=self.device, **kws) def inputs(self) -> list[torch.Tensor]: """Utility for inputs to the pattern""" @@ -166,7 +168,7 @@ def forward_custom( input: torch.Tensor, scale: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - # TODO: why does empty_like produce a permute but + # TODO(luka): why does empty_like produce a permute but # empty via shape doesn't? result = torch.empty( input.shape, device=input.device, dtype=self.quant_key.dtype From 2a6299c81b0e8de81161ff9efe4af719eed1b381 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 04:12:01 -0400 Subject: [PATCH 055/137] Fix e2e test patterns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusions_e2e.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 533e0c5867d3..a8ece68d4f0e 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -161,8 +161,7 @@ def test_attn_quant( run_model(compilation_config, model_name, **model_kwargs) matches = re.findall( - r"\[compilation/fusion_attn.py:\d+] " - r"Fused quant onto (\d+) attention nodes", + r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) assert len(matches) == 1, log_holder.text @@ -252,8 +251,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm( compilation_config, model_name, tensor_parallel_size=2, **model_kwargs ) matches = re.findall( - r"\[compilation/fusion_attn.py:\d+] " - r"Fused quant onto (\d+) attention nodes", + r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) assert len(matches) == 2, log_holder.text @@ -262,8 +260,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm( assert int(matches[1]) == attention_fusions matches = re.findall( - r"\[compilation/collective_fusion.py:\d+] " - r"Replaced (\d+) patterns", + r"collective_fusion.py:\d+] Replaced (\d+) patterns", log_holder.text, ) assert len(matches) == 2, log_holder.text From 465ce583f239e67e8518032d312debeab230cea4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 09:59:54 -0400 Subject: [PATCH 056/137] Update tests/compile/test_fusion.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 9f7e025a232e..aa37db8022d5 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -115,7 +115,7 @@ def ops_in_model_before_partial(self): return [RMS_OP, RMS_ADD_OP] if self.enable_rms_norm else [torch.ops.aten.rsqrt] -@pytest.mark.parametrize("dtype", [torch.float16]) # , torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) From bcd95b5f67a0a51580be925073a3c61a5fcb1655 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 11:54:47 -0400 Subject: [PATCH 057/137] Fix func test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- csrc/layernorm_quant_kernels.cu | 2 ++ tests/compile/test_functionalization.py | 34 +++++++++++-------------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 0fc462194fcd..f82ae50ae6dd 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -216,6 +216,8 @@ void fused_add_rms_norm_static_fp8_quant( double epsilon) { TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(residual.is_contiguous()); + TORCH_CHECK(residual.scalar_type() == input.scalar_type()); + TORCH_CHECK(weight.scalar_type() == input.scalar_type()); int hidden_size = input.size(-1); int input_stride = input.stride(-2); int num_tokens = input.numel() / hidden_size; diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index dd424d7f6ad0..11ae96e930da 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -54,8 +54,7 @@ def forward(self, x): return y def example_inputs(self, num_tokens=32, hidden_size=128): - dtype = torch.float16 if TEST_FP8 else torch.float32 - return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),) + return (torch.rand(num_tokens, hidden_size * 2),) def ops_in_model(self, do_fusion): if TEST_FP8 and do_fusion: @@ -73,15 +72,11 @@ def __init__(self, hidden_size=16, intermediate_size=32): self.hidden_size = hidden_size self.intermediate_size = intermediate_size - dtype = torch.float16 if TEST_FP8 else torch.float32 - self.gate_proj = torch.nn.Parameter( - torch.empty((intermediate_size, hidden_size), dtype=dtype) + torch.empty((intermediate_size, hidden_size)) ) self.norm = RMSNorm(intermediate_size, 1e-05) - self.norm.weight = torch.nn.Parameter( - torch.ones(intermediate_size, dtype=dtype) - ) + self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size)) torch.nn.init.normal_(self.gate_proj, std=0.02) @@ -118,9 +113,8 @@ def forward(self, hidden_states, residual): return norm_output, residual_output def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16): - dtype = torch.float16 if TEST_FP8 else torch.float32 - hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + hidden_states = torch.randn((batch_size * seq_len, hidden_size)) + residual = torch.randn((batch_size * seq_len, hidden_size)) return (hidden_states, residual) def ops_in_model(self, do_fusion): @@ -151,10 +145,9 @@ def forward(self, positions, q, k): return q_rotated, k_rotated def example_inputs(self, num_tokens=32, head_dim=64): - dtype = torch.float16 positions = torch.arange(num_tokens, dtype=torch.long) - q = torch.randn(num_tokens, head_dim, dtype=dtype) - k = torch.randn(num_tokens, head_dim, dtype=dtype) + q = torch.randn(num_tokens, head_dim) + k = torch.randn(num_tokens, head_dim) return (positions, q, k) def ops_in_model(self, do_fusion): @@ -172,7 +165,7 @@ def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000): self.hidden_size = head_dim * num_heads self.qkv_proj = torch.nn.Linear( - self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16 + self.hidden_size, self.hidden_size * 3, bias=False ) self.rotary_emb = get_rope( @@ -196,10 +189,9 @@ def forward(self, positions, hidden_states): return qkv_updated def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4): - dtype = torch.float16 hidden_size = head_dim * num_heads positions = torch.arange(num_tokens, dtype=torch.long) - hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + hidden_states = torch.randn(num_tokens, hidden_size) return (positions, hidden_states) def ops_in_model(self, do_fusion): @@ -217,14 +209,18 @@ def ops_not_in_model(self): ] +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("model_class", MODELS) @pytest.mark.parametrize("do_fusion", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool): +def test_fix_functionalization( + model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype +): torch.set_default_device("cuda") + torch.set_default_dtype(dtype) vllm_config = VllmConfig( - model_config=ModelConfig(dtype=torch.bfloat16), + model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( custom_ops=["all"], pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True), From db2b1c76be4bc5ddb5d0a7b3f37b70f88ee2100f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 11:59:35 -0400 Subject: [PATCH 058/137] Smaller model for e2e fusion test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusions_e2e.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index a8ece68d4f0e..5d5750ca3715 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -37,11 +37,12 @@ class ModelBackendTestCase(NamedTuple): if current_platform.is_cuda(): MODELS_FP8 = [ ModelBackendTestCase( - model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + # Use smaller model for L40s in CI + model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", model_kwargs=dict(max_model_len=1024), backend=_Backend.TRITON_ATTN, - attention_fusions=48, - allreduce_fusions=96, + attention_fusions=32, + allreduce_fusions=65, ), ModelBackendTestCase( model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", From a3ebf0a2e47adf60eab806d604c27de4795847c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 12:09:48 -0400 Subject: [PATCH 059/137] fix fp8 quant tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/kernels/quant_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 9d11a7ef6413..34ce91585520 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant( .clamp(fp8_traits_min, fp8_traits_max) .to(FP8_DTYPE) ) - return ref_out, ref_scale.view((1,)) + return ref_out, ref_scale.view((1, 1)) def native_w8a8_block_matmul( From 3943257943e9a5aa8161190f581ae0592778db0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 12:11:29 -0400 Subject: [PATCH 060/137] Restore original torch.Parameter behavior in RMSNorm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/model_executor/layers/layernorm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 46a5dec14327..1e5703f4368c 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -170,7 +170,9 @@ def __init__( ) weight_dtype = dtype or torch.get_default_dtype() self.has_weight = has_weight - self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) + self.weight = torch.ones(hidden_size, dtype=weight_dtype) + if self.has_weight: + self.weight = nn.Parameter(self.weight) if current_platform.is_rocm(): self.rocm_norm_func = dispatch_rocm_rmsnorm_func( From 532cbcf134e688bb960357706ed508afc15a17de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 12:56:07 -0400 Subject: [PATCH 061/137] Add comment to test_logger MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/test_logger.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_logger.py b/tests/test_logger.py index f1c31c245475..01672358902f 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -503,6 +503,7 @@ def test_streaming_complete_logs_full_text_content(): assert call_args[5] == "streaming_complete" +# Add vllm prefix to make sure logs go through the vllm logger test_logger = init_logger("vllm.test_logger") From 7e6f5b3f85763bcc1774647251ffbb521cb350f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 13:06:19 -0400 Subject: [PATCH 062/137] add flat_product example MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/utils.py b/tests/utils.py index 3042bacd4bb6..9aed55b7258b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1269,6 +1269,15 @@ def flat_product(*iterables: Iterable[Any]): Flatten lists of tuples of the cartesian product. Useful when we want to avoid nested tuples to allow test params to be unpacked directly from the decorator. + + Example: + flat_product([(1, 2), (3, 4)], ["a", "b"]) -> + [ + (1, 2, "a"), + (1, 2, "b"), + (3, 4, "a"), + (3, 4, "b"), + ] """ for element in itertools.product(*iterables): normalized = (e if isinstance(e, tuple) else (e,) for e in element) From 24f1298435681914f58e9f25d211a226583b9a77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 13:08:13 -0400 Subject: [PATCH 063/137] PR comments: cleanup fusion passes, & matching MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/collective_fusion.py | 42 ++++++++++----------------- vllm/compilation/fusion.py | 2 -- vllm/compilation/matcher_utils.py | 4 +-- 3 files changed, 17 insertions(+), 31 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 599a30f72c8f..c1ed058ded70 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -673,10 +673,10 @@ def __init__( self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def get_inputs(self): - input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4], device=self.device, dtype=self.dtype) + input, weight = self.rmsnorm_matcher.inputs() - return [input, weight] + # input goes through allreduce first, always 16-bit + return [input.to(self.dtype), weight] def register(self, pm_pass: PatternMatcherPass): def pattern(input: torch.Tensor, weight: torch.Tensor): @@ -728,14 +728,10 @@ def __init__( self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) def get_inputs(self): - input = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - return [ - residual, - input, - weight, - ] + input, residual, weight = self.rmsnorm_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight] def register(self, pm_pass: PatternMatcherPass): def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): @@ -802,10 +798,11 @@ def __init__( def register(self, pm_pass: PatternMatcherPass): def get_inputs(): - input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4], device=self.device, dtype=self.dtype) - scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - return [input, weight, scale] + input, weight = self.rmsnorm_matcher.inputs() + _, scale = self.quant_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [input.to(self.dtype), weight, scale] def pattern( input: torch.Tensor, @@ -871,18 +868,11 @@ def __init__( def register(self, pm_pass: PatternMatcherPass): def get_inputs(): - input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + input, residual, weight = self.rmsnorm_matcher.inputs() + _, scale = self.quant_matcher.inputs() - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) - - return [ - residual, - input, - weight, - scale, - ] + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight, scale] def pattern( residual: torch.Tensor, diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index d724eca03e82..606874cc1034 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -182,7 +182,6 @@ def replacement( # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=self.model_dtype) - residual = residual.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( @@ -292,7 +291,6 @@ def replacement( # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=self.model_dtype) - residual = residual.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 9b3854d9fb52..16d1d86d2b3e 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -73,9 +73,7 @@ def __init__(self, epsilon: float, enabled: bool | None = None): def inputs(self): input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) - weight = self.empty( - 16, - ) + weight = self.empty(16) return [input, weight] def forward_custom( From de7405b851d909dd8bb0241c81ca9c59bf5001bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 13:08:57 -0400 Subject: [PATCH 064/137] PR comments: add _custom_op suffix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 26 ++++++++++++++----------- tests/compile/test_fusion_all_reduce.py | 20 +++++++++---------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index aa37db8022d5..8c388f13002f 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -71,8 +71,8 @@ def __init__( act_quant_group_shape=group_shape, ) - self.enable_rms_norm = self.norm[0].enabled() - self.enable_quant_fp8 = self.fp8_linear.quant_fp8.enabled() + self.enable_rms_norm_custom_op = self.norm[0].enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() def forward(self, x): # avoid having graph input be an arg to a pattern directly @@ -107,12 +107,16 @@ def ops_in_model_after(self): def ops_in_model_before(self): return ( [QUANT_OPS[self.quant_key]] - if self.enable_quant_fp8 + if self.enable_quant_fp8_custom_op else [torch.ops.aten.reciprocal] ) def ops_in_model_before_partial(self): - return [RMS_OP, RMS_ADD_OP] if self.enable_rms_norm else [torch.ops.aten.rsqrt] + return ( + [RMS_OP, RMS_ADD_OP] + if self.enable_rms_norm_custom_op + else [torch.ops.aten.rsqrt] + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -120,8 +124,8 @@ def ops_in_model_before_partial(self): @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) -@pytest.mark.parametrize("enable_rms_norm", [True, False]) -@pytest.mark.parametrize("enable_quant_fp8", [True, False]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) +@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. @pytest.mark.parametrize( @@ -136,8 +140,8 @@ def test_fusion_rmsnorm_quant( num_tokens, eps, static, - enable_rms_norm, - enable_quant_fp8, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, cuda_force_torch, ): torch.set_default_device("cuda") @@ -146,9 +150,9 @@ def test_fusion_rmsnorm_quant( maybe_create_device_identity() # needed for certain non-cutlass fp8 paths custom_ops = [] - if enable_rms_norm: + if enable_rms_norm_custom_op: custom_ops.append("+rms_norm") - if enable_quant_fp8: + if enable_quant_fp8_custom_op: custom_ops.append("+quant_fp8") vllm_config = VllmConfig( model_config=ModelConfig(dtype=dtype), @@ -195,7 +199,7 @@ def test_fusion_rmsnorm_quant( # there's a risk that the fused add doesn't get included in the # replacement and only the rms part gets fused with quant. # Hence, we check only 2 add nodes are left (final fused rmsnorm add). - if not enable_rms_norm: + if not enable_rms_norm_custom_op: n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) assert n_add_nodes(backend.graph_pre_pass) == 7 diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 4e6ed4446e4c..7688ba3d1b6c 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -194,7 +194,7 @@ def ops_in_model_before(self): @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "test_model, enable_quant_fp8", + "test_model, enable_quant_fp8_custom_op", [ (TestAllReduceRMSNormModel, False), (TestAllReduceRMSNormStaticQuantFP8Model, True), @@ -206,7 +206,7 @@ def ops_in_model_before(self): @pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("enable_rms_norm", [True, False]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( not find_spec("flashinfer") @@ -220,8 +220,8 @@ def test_all_reduce_fusion_pass_replace( seq_len: int, hidden_size: int, dtype: torch.dtype, - enable_rms_norm, - enable_quant_fp8, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, ): num_processes = 2 if ( @@ -243,8 +243,8 @@ def run_torch_spawn(fn, nprocs): seq_len, hidden_size, dtype, - enable_rms_norm, - enable_quant_fp8, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, ), nprocs=nprocs, ) @@ -260,8 +260,8 @@ def all_reduce_fusion_pass_on_test_model( seq_len: int, hidden_size: int, dtype: torch.dtype, - enable_rms_norm, - enable_quant_fp8, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, ): current_platform.seed_everything(0) @@ -284,9 +284,9 @@ def all_reduce_fusion_pass_on_test_model( initialize_model_parallel(tensor_model_parallel_size=world_size) custom_ops = [] - if enable_rms_norm: + if enable_rms_norm_custom_op: custom_ops.append("+rms_norm") - if enable_quant_fp8: + if enable_quant_fp8_custom_op: custom_ops.append("+quant_fp8") vllm_config = VllmConfig( From 6253d5bd143a1975213462e7d6c4f8d3a2e1fef7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 13:18:03 -0400 Subject: [PATCH 065/137] Add e2e to L40 distributed, move tests to start of B200 distributed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9d98c5adf6ae..29cce6b398e0 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -951,6 +951,7 @@ steps: - vllm/v1/worker/ - tests/compile/test_basic_correctness.py - tests/compile/test_wrapper.py + - tests/compile/test_fusions_e2e.py - tests/distributed/ - tests/entrypoints/llm/test_collective_rpc.py - tests/v1/distributed @@ -964,6 +965,7 @@ steps: - pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py + - pytest -v -s ./compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' - pytest -v -s distributed/test_sequence_parallel.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown @@ -1122,10 +1124,10 @@ steps: working_dir: "/vllm-workspace/" num_gpus: 2 commands: - - pytest -v -s tests/distributed/test_context_parallel.py - - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm + - pytest -v -s tests/distributed/test_context_parallel.py + - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py ##### RL Integration Tests ##### - label: Prime-RL Integration Test # 15min From 876ef22e1e2921ed84b615a55edb442f627b42b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 18:43:48 -0400 Subject: [PATCH 066/137] Fix tests, PR feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 17 +++++++++++------ tests/compile/test_sequence_parallelism.py | 7 +++---- vllm/compilation/fusion.py | 6 +++--- vllm/compilation/matcher_utils.py | 11 ++++++++--- 4 files changed, 25 insertions(+), 16 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 8c388f13002f..aa0728e39c94 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -169,24 +169,29 @@ def test_fusion_rmsnorm_quant( cleanup_pass = PostCleanupPass(vllm_config) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + backend2 = TestBackend(noop_pass, cleanup_pass) model = TestModel(hidden_size, eps, static, cuda_force_torch) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) torch._dynamo.mark_dynamic(x, 0) - result = model(x) + model_fused = torch.compile(model, backend=backend) + result_fused = model_fused(x) - model2 = torch.compile(model, backend=backend) - result2 = model2(x) + model_unfused = torch.compile(model, backend=backend2) + result_unfused = model_unfused(x) - # Higher tol for dynamic bfloat16 - if dtype == torch.float16 or static: + if enable_rms_norm_custom_op and static: + ATOL, RTOL = (1e-5, 1e-5) # up to 1e-8 close + elif dtype == torch.float16: ATOL, RTOL = (2e-3, 2e-3) + elif static: + ATOL, RTOL = (5e-3, 5e-3) else: ATOL, RTOL = (1e-2, 1e-2) - torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL) assert fusion_pass.matched_count == 3 backend.check_before_ops(model.ops_in_model_before()) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index ba2178964ff3..24bc88d44f38 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -18,6 +18,7 @@ ModelConfig, PassConfig, VllmConfig, + get_current_vllm_config, set_current_vllm_config, ) from vllm.distributed import tensor_model_parallel_all_reduce @@ -94,13 +95,11 @@ def ops_in_model(self): class TestQuantModel(torch.nn.Module): - def __init__( - self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None - ): + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.vllm_config = vllm_config + self.vllm_config = get_current_vllm_config() self.gate_proj = torch.nn.Parameter( torch.empty((intermediate_size, hidden_size)), requires_grad=False ) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 606874cc1034..98703ed5f007 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -33,7 +33,7 @@ def empty_bf16(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.float16, device="cuda") + return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") def empty_fp32(*args, **kwargs): @@ -144,7 +144,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): inputs = [ # input, weight *self.rmsnorm_matcher.inputs(), - empty_fp32(1, 1), # scale + self.quant_matcher.inputs()[1], # scale ] pattern(*inputs) @@ -200,7 +200,7 @@ def replacement( inputs = [ # input, weight, residual *self.rmsnorm_matcher.inputs(), - empty_fp32(1, 1), # scale + self.quant_matcher.inputs()[1], # scale ] pm.register_replacement( diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 16d1d86d2b3e..8be4de96ebbf 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -112,9 +112,7 @@ def __init__(self, epsilon: float, enabled: bool | None = None): def inputs(self): input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) - weight = self.empty( - 16, - ) + weight = self.empty(16) residual = self.empty(5, 16) return [input, weight, residual] @@ -203,3 +201,10 @@ def make_scale(self, input: torch.Tensor): ) return torch.empty(scale_shape, device=input.device, dtype=torch.float32) + + def inputs(self) -> list[torch.Tensor]: + input = self.empty(5, 16) + if self.quant_key.scale.static: + return [input, self.empty_f32(1, 1)] + + return [input] From e99a7598260d1a3f33cb85be27eff177a4b28dee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 19:20:47 -0400 Subject: [PATCH 067/137] Break up B200 tests, move allreduce to H200 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 29cce6b398e0..df5b474bb729 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -808,7 +808,7 @@ steps: # Whisper needs spawn method to avoid deadlock - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper -- label: Blackwell Test # 48 min +- label: Blackwell Test # TODO min timeout_in_minutes: 70 working_dir: "/vllm-workspace/" gpu: b200 @@ -822,11 +822,6 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - - vllm/compilation/ - # can affect pattern matching - - vllm/model_executor/layers/layernorm.py - - vllm/model_executor/layers/activation.py - - vllm/model_executor/layers/quantization/input_quant_fp8.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -847,10 +842,27 @@ steps: - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - # Fusion - - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py + +- label: Blackwell Fusion Tests # TODO min + timeout_in_minutes: 70 + working_dir: "/vllm-workspace/" + gpu: b200 + source_file_dependencies: + - csrc/quantization/fp4/ + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + commands: + - nvidia-smi + - pytest -v -s tests/compile/test_fusion_attn.py - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py + # this runner has 2 GPUs available even though num_gpus=2 is not set + - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusions_e2e.py - label: Blackwell GPT-OSS Eval @@ -951,7 +963,6 @@ steps: - vllm/v1/worker/ - tests/compile/test_basic_correctness.py - tests/compile/test_wrapper.py - - tests/compile/test_fusions_e2e.py - tests/distributed/ - tests/entrypoints/llm/test_collective_rpc.py - tests/v1/distributed @@ -965,7 +976,6 @@ steps: - pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - - pytest -v -s ./compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' - pytest -v -s distributed/test_sequence_parallel.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown @@ -1114,6 +1124,8 @@ steps: commands: - pytest -v -s tests/compile/test_async_tp.py - pytest -v -s tests/compile/test_sequence_parallelism.py + - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - pytest -v -s tests/distributed/test_context_parallel.py - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 @@ -1124,8 +1136,6 @@ steps: working_dir: "/vllm-workspace/" num_gpus: 2 commands: - - pytest -v -s tests/compile/test_fusion_all_reduce.py - - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py From ae581e176787d4fab88438330a8b93add1f5ce48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 20:30:02 -0400 Subject: [PATCH 068/137] Fix attention fusion test numerics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 691d9256d7be..2498c2d58a31 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -368,8 +368,9 @@ def test_attention_quant_pattern( forward_ctx = get_forward_context() forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size) - # Run model directly without compilation and fusion - result_unfused = model_unfused(q, k, v) + # Run model directly without fusion + # Still compile so query QuantFP8 has closer numerics + result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v) # Run model with attn fusion enabled vllm_config.compilation_config.pass_config = PassConfig( From c03b29bfb520735b4460dd5c4bf1b8ee3d5743cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 20:31:11 -0400 Subject: [PATCH 069/137] Remove inductor graph partition from unit test (included in e2e tests) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 2498c2d58a31..fecb1e2e918f 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -35,7 +35,6 @@ ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer from vllm.utils.flashinfer import has_flashinfer from vllm.v1.kv_cache_interface import AttentionSpec @@ -290,7 +289,6 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): # quant_fp4 only has the custom impl + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])), ) -@pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" ) @@ -305,7 +303,6 @@ def test_attention_quant_pattern( model_name: str, model_class: type[AttentionQuantPatternModel], backend: _Backend, - use_inductor_graph_partition: bool, dist_init, ): """Test AttentionStaticQuantPattern fusion pass""" @@ -314,10 +311,6 @@ def test_attention_quant_pattern( ): pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") - # TODO(boyuan/luka): test inductor graph partition on rocm - if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): - pytest.skip("Inductor graph partition requires torch>=2.9") - custom_ops_list = custom_ops.split(",") if custom_ops else [] device = torch.device("cuda:0") @@ -333,7 +326,6 @@ def test_attention_quant_pattern( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops_list, - use_inductor_graph_partition=use_inductor_graph_partition, ), cache_config=CacheConfig(cache_dtype="fp8"), ) From d2e0489da1200b387c09c7867b465e5a18c2275e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 16 Oct 2025 00:31:15 -0400 Subject: [PATCH 070/137] Relax tolerance for L40 fusion test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index aa0728e39c94..4e42094f73e6 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -182,9 +182,7 @@ def test_fusion_rmsnorm_quant( model_unfused = torch.compile(model, backend=backend2) result_unfused = model_unfused(x) - if enable_rms_norm_custom_op and static: - ATOL, RTOL = (1e-5, 1e-5) # up to 1e-8 close - elif dtype == torch.float16: + if dtype == torch.float16: ATOL, RTOL = (2e-3, 2e-3) elif static: ATOL, RTOL = (5e-3, 5e-3) From d4fe977cdfe5419afd297c90fae45171ac004fb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 16 Oct 2025 00:54:25 -0400 Subject: [PATCH 071/137] Fix NamedTuple MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusions_e2e.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 5d5750ca3715..7399abaec542 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -6,11 +6,10 @@ import itertools import logging from collections.abc import Iterable -from typing import Any +from typing import Any, NamedTuple import pytest import regex as re -from black.cache import NamedTuple from tests.v1.attention.utils import _Backend from vllm import LLM, SamplingParams From 6319e39757784acb19d84cec9a89791dc8939c4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 16 Oct 2025 00:58:59 -0400 Subject: [PATCH 072/137] Update test durations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 236d6d4c8be5..238b6ef98bf2 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -808,8 +808,8 @@ steps: # Whisper needs spawn method to avoid deadlock - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper -- label: Blackwell Test # TODO min - timeout_in_minutes: 70 +- label: Blackwell Test # 21 min + timeout_in_minutes: 30 working_dir: "/vllm-workspace/" gpu: b200 # optional: true @@ -844,8 +844,8 @@ steps: - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_flashinfer.py -- label: Blackwell Fusion Tests # TODO min - timeout_in_minutes: 70 +- label: Blackwell Fusion Tests # 30 min + timeout_in_minutes: 40 working_dir: "/vllm-workspace/" gpu: b200 source_file_dependencies: From e34d36d2e13b25d066bd14a111d9cb3db998d34f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 16 Oct 2025 09:33:16 -0400 Subject: [PATCH 073/137] More tweaking of precision MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4e42094f73e6..286f2276367a 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -184,8 +184,6 @@ def test_fusion_rmsnorm_quant( if dtype == torch.float16: ATOL, RTOL = (2e-3, 2e-3) - elif static: - ATOL, RTOL = (5e-3, 5e-3) else: ATOL, RTOL = (1e-2, 1e-2) From f72ee4385c014ec68b96c0b72a130f2b6bd94ccd Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 4 Sep 2025 04:29:31 -0700 Subject: [PATCH 074/137] Split original pr Signed-off-by: ilmarkov --- .../kernels/benchmark_fused_collective.py | 1270 +++++++++++++++++ tests/compile/test_fusion_all_reduce.py | 2 +- vllm/compilation/collective_fusion.py | 98 +- vllm/config/compilation.py | 64 +- 4 files changed, 1381 insertions(+), 53 deletions(-) create mode 100644 benchmarks/kernels/benchmark_fused_collective.py diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py new file mode 100644 index 000000000000..ea78875c62cf --- /dev/null +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -0,0 +1,1270 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark for FlashInfer fused collective operations vs standard operations. + +This benchmark compares: +1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant) +2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations + +Usage with torchrun: + torchrun --nproc_per_node=2 benchmark_fused_collective.py + +""" + +import argparse +import itertools +import os +import time +from typing import Optional + +import torch # type: ignore +import torch.distributed as dist # type: ignore + +from vllm.distributed import ( + get_tp_group, + tensor_model_parallel_all_reduce, +) +from vllm.distributed.parallel_state import ( + graph_capture, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm # noqa +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 # noqa +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape # noqa +from vllm.platforms import current_platform # noqa + +RMS_NORM_OP = torch.ops._C.rms_norm +FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm +RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant +FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = ( + torch.ops._C.fused_add_rms_norm_static_fp8_quant +) +SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant + +logger = init_logger(__name__) + +# Try to import FlashInfer +try: + import flashinfer.comm as flashinfer_comm # type: ignore + + if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"): + flashinfer_comm = None + logger.warning( + "FlashInfer comm module found but missing trtllm_allreduce_fusion" + ) +except ImportError: + flashinfer_comm = None + logger.warning("FlashInfer not found, only benchmarking standard operations") + +# Constants +FP8_DTYPE = current_platform.fp8_dtype() +MiB = 1024 * 1024 + +# FlashInfer max sizes per world size +# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes +# use --disable-oneshot to disable oneshot mode for very large input sizes +_FI_MAX_SIZES = { + 2: 64 * MiB, # 64MB + 4: 64 * MiB, # 64MB + 8: 64 * MiB, # 64MB +} + +# Global workspace tensor for FlashInfer +_FI_WORKSPACE_TENSOR = None + + +def setup_flashinfer_workspace( + world_size: int, + rank: int, + hidden_dim: int, + max_token_num: int, + use_fp32_lamport: bool = False, +): + """Setup FlashInfer workspace for fused allreduce operations.""" + global _FI_WORKSPACE_TENSOR + + if flashinfer_comm is None: + return None, None + + if world_size not in _FI_MAX_SIZES: + logger.warning("FlashInfer not supported for world size %s", world_size) + return None, None + + try: + # Create IPC workspace + ipc_handles, workspace_tensor = ( + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_rank=rank, + tp_size=world_size, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + group=get_tp_group().device_group, + use_fp32_lamport=use_fp32_lamport, + ) + ) + + _FI_WORKSPACE_TENSOR = workspace_tensor + return ipc_handles, workspace_tensor + except Exception as e: + logger.error("Failed to setup FlashInfer workspace: %s", e) + return None, None + + +def cleanup_flashinfer_workspace(ipc_handles): + """Cleanup FlashInfer workspace.""" + if flashinfer_comm is None or ipc_handles is None: + return + + try: + group = get_tp_group().device_group + flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group) + except Exception as e: + logger.error("Failed to cleanup FlashInfer workspace: %s", e) + + +class FlashInferFusedAllReduceParams: + """Parameters for FlashInfer fused allreduce operations.""" + + def __init__( + self, + rank: int, + world_size: int, + use_fp32_lamport: bool = False, + max_token_num: int = 1024, + ): + self.rank = rank + self.world_size = world_size + self.use_fp32_lamport = use_fp32_lamport + self.trigger_completion_at_end = True + self.launch_with_pdl = True + self.fp32_acc = True + self.max_token_num = max_token_num + + def get_trtllm_fused_allreduce_kwargs(self): + return { + "world_rank": self.rank, + "world_size": self.world_size, + "launch_with_pdl": self.launch_with_pdl, + "trigger_completion_at_end": self.trigger_completion_at_end, + "fp32_acc": self.fp32_acc, + } + + +def flashinfer_fused_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + allreduce_params: "FlashInferFusedAllReduceParams", + use_oneshot: bool, + norm_out: Optional[torch.Tensor] = None, +): + """FlashInfer fused allreduce + rmsnorm operation.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + allreduce_out=None, + quant_out=None, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4_, + scale_factor=None, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + use_oneshot: bool = True, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + quant_out: torch.Tensor, + use_oneshot: bool, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=output_scale, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=input_global_scale, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def standard_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm operations.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + # Then RMS norm + if residual is not None: + # Fused add + RMS norm + FUSED_ADD_RMS_NORM_OP(allreduce_out, residual, rms_gamma, rms_eps) + else: + # Just RMS norm + if norm_out is None: + norm_out = torch.empty_like(allreduce_out) + RMS_NORM_OP(norm_out, allreduce_out, rms_gamma, rms_eps) + + +def standard_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP8 quantization.""" + if quant_out is None: + quant_out = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Then fused RMS norm + FP8 quantization + if residual is not None: + FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP( + quant_out, allreduce_out, residual, rms_gamma, scale_factor, rms_eps + ) + return quant_out, residual + else: + RMS_NORM_STATIC_FP8_QUANT_OP( + quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps + ) + return quant_out + + +def standard_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP4 quantization.""" + + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Then RMS norm + if residual is not None: + FUSED_ADD_RMS_NORM_OP(allreduce_out, residual, rms_gamma, rms_eps) + quant_input = allreduce_out + residual_out = residual + else: + if norm_out is None: + norm_out = torch.empty_like(allreduce_out) + RMS_NORM_OP(norm_out, allreduce_out, rms_gamma, rms_eps) + quant_input = norm_out + residual_out = allreduce_out + + # Finally FP4 quantization + SCALED_FP4_QUANT_OP(quant_out, quant_input, output_scale, input_global_scale) + if residual is not None: + return quant_out, residual_out, output_scale + else: + return quant_out, norm_out + + +def standard_allreduce_rmsnorm_native( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm operations using native RMSNorm forward.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + # Apply native RMSNorm + if residual is not None: + result = rmsnorm_layer.forward_native(allreduce_out, residual) + return result # Returns (norm_out, residual_out) + else: + result = rmsnorm_layer.forward_native(allreduce_out) + return result # Returns norm_out + + +def standard_allreduce_rmsnorm_fp8_quant_native( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + quant_fp8_layer: QuantFP8, + scale_factor: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP8 quantization using native implementations.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Apply native RMSNorm + if residual is not None: + norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) + else: + norm_out = rmsnorm_layer.forward_native(allreduce_out) + residual_out = allreduce_out + + # Apply native FP8 quantization + quant_out, _ = quant_fp8_layer.forward_native(norm_out, scale=scale_factor) + + if residual is not None: + return quant_out, residual_out + else: + return quant_out + + +def standard_allreduce_rmsnorm_fp4_quant_native( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Apply native RMSNorm + if residual is not None: + norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) + quant_input = norm_out + else: + norm_out = rmsnorm_layer.forward_native(allreduce_out) + quant_input = norm_out + residual_out = allreduce_out + + # Apply FP4 quantization (still using fused CUDA op as there's no native FP4) + SCALED_FP4_QUANT_OP(quant_out, quant_input, output_scale, input_global_scale) + + if residual is not None: + return quant_out, residual_out, output_scale + else: + return quant_out, norm_out + + +# Compiled versions of native functions +@torch.compile +def standard_allreduce_rmsnorm_native_compiled( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + norm_out: Optional[torch.Tensor] = None, +): + """Compiled version of standard allreduce + rmsnorm.""" + return standard_allreduce_rmsnorm_native( + input_tensor, residual, rmsnorm_layer, norm_out + ) + + +@torch.compile +def standard_allreduce_rmsnorm_fp8_quant_native_compiled( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + quant_fp8_layer: QuantFP8, + scale_factor: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """Compiled version of standard allreduce + rmsnorm + FP8 quantization.""" + return standard_allreduce_rmsnorm_fp8_quant_native( + input_tensor, + residual, + rmsnorm_layer, + quant_fp8_layer, + scale_factor, + norm_out, + quant_out, + ) + + +@torch.compile +def standard_allreduce_rmsnorm_fp4_quant_native_compiled( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """Compiled version of standard allreduce + rmsnorm + FP4 quantization.""" + return standard_allreduce_rmsnorm_fp4_quant_native( + input_tensor, + residual, + rmsnorm_layer, + input_global_scale, + quant_out, + output_scale, + norm_out, + ) + + +def create_test_tensors( + seq_len: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True +): + """Create test tensors for benchmarking.""" + input_tensor = torch.randn(seq_len, hidden_dim, dtype=dtype) + residual = ( + torch.randn_like(input_tensor) + if use_residual + else torch.zeros_like(input_tensor) + ) + rms_gamma = torch.ones(hidden_dim, dtype=dtype) + norm_out = None if use_residual else torch.empty_like(input_tensor) + + # Quantization scales + scale_fp8 = torch.tensor(1.0, dtype=torch.float32) + scale_fp4 = torch.tensor(1.0, dtype=torch.float32) + quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks) + fp4_quant_out = torch.empty((seq_len, hidden_dim // 2), dtype=torch.uint8) + fp4_output_scale = torch.empty((128, 4), dtype=torch.int32) + + return ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) + + +def benchmark_operation( + operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs +): + """Benchmark a single operation using CUDA graphs.""" + # Warmup before graph capture + for _ in range(warmup): + operation_func(*args, **kwargs) + torch.cuda.synchronize() + + # Create CUDA graph + graph = torch.cuda.CUDAGraph() + num_op_per_cudagraph = 10 + + # Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe + device = torch.device(f"cuda:{torch.cuda.current_device()}") + with graph_capture(device=device), torch.cuda.graph(graph): + for _ in range(num_op_per_cudagraph): + operation_func(*args, **kwargs) + + # Graph warmup + torch.cuda.synchronize() + for _ in range(warmup): + graph.replay() + + # Benchmark with CUDA graph + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(trials // num_op_per_cudagraph): + # operation_func(*args, **kwargs) + graph.replay() + + torch.cuda.synchronize() + end_time = time.perf_counter() + + avg_time_ms = ((end_time - start_time) / trials) * 1000 + return avg_time_ms + + +def run_benchmarks( + seq_len: int, + hidden_dim: int, + dtype: torch.dtype, + use_residual: bool, + allreduce_params: Optional[FlashInferFusedAllReduceParams], + quant_mode: str = "all", + disable_oneshot: bool = False, +): + """Run all benchmarks for given configuration. + + Args: + quant_mode: "none", "fp8_only", "fp4_only", or "all" + """ + ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) = create_test_tensors(seq_len, hidden_dim, dtype, use_residual) + + rms_eps = 1e-6 + results = {} + + # Create RMSNorm and QuantFP8 layers once for native benchmarks + rmsnorm_layer = RMSNorm(hidden_dim, eps=rms_eps, dtype=dtype) + rmsnorm_layer.weight.data = rms_gamma + quant_fp8_layer = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) + + if quant_mode in ["all", "none"]: + # Standard AllReduce + RMSNorm + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + ) + results["standard_allreduce_rmsnorm"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm failed: %s", e) + results["standard_allreduce_rmsnorm"] = float("inf") + + # Standard AllReduce + RMSNorm Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + norm_out=norm_out, + ) + results["standard_allreduce_rmsnorm_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms + except Exception as e: + logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e) + results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Two-shot + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = time_ms + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm Two-shot failed: %s", e + ) + results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = float("inf") + + if quant_mode in ["all", "fp8_only"]: + # Standard AllReduce + RMSNorm + FP8 Quant + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) + results["standard_allreduce_rmsnorm_fp8_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + quant_fp8_layer=quant_fp8_layer, + scale_factor=scale_fp8, + norm_out=norm_out, + quant_out=quant_out_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = float( + "inf" + ) + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Two-shot + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = float( + "inf" + ) + + if quant_mode in ["all", "fp4_only"]: + # Standard AllReduce + RMSNorm + FP4 Quant + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + ) + results["standard_allreduce_rmsnorm_fp4_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e) + results["standard_allreduce_rmsnorm_fp4_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + norm_out=norm_out, + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot + if flashinfer_comm is not None and allreduce_params is not None: + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float( + "inf" + ) + + return results + + +def prepare_results_with_speedups(results_dict): + """Prepare results with speedup calculations based on dynamic baseline selection.""" + prepared_results = [] + + # Determine the fastest baseline for each operation type + def get_fastest_baseline(op_name, results_dict): + """Get the fastest baseline between standard and native_compiled versions.""" + if "fp8_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp8_quant", + "standard_allreduce_rmsnorm_fp8_quant_native_compiled", + ] + elif "fp4_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp4_quant", + "standard_allreduce_rmsnorm_fp4_quant_native_compiled", + ] + else: + candidates = [ + "standard_allreduce_rmsnorm", + "standard_allreduce_rmsnorm_native_compiled", + ] + + # Find the fastest among available candidates + fastest_time = float("inf") + fastest_baseline = None + + for candidate in candidates: + if ( + candidate in results_dict + and results_dict[candidate] != float("inf") + and results_dict[candidate] < fastest_time + ): + fastest_time = results_dict[candidate] + fastest_baseline = candidate + + return fastest_baseline + + # Create dynamic baseline mapping + dynamic_baseline_mapping = {} + for op_name in results_dict: + if ( + op_name.startswith("flashinfer_") + or op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + dynamic_baseline_mapping[op_name] = get_fastest_baseline( + op_name, results_dict + ) + + for op_name, time_ms in results_dict.items(): + if time_ms == float("inf"): + speedup_str = "FAILED" + time_str = "FAILED" + else: + time_str = f"{time_ms:.3f}" + # Find the appropriate baseline for this operation + baseline_op = dynamic_baseline_mapping.get(op_name) + if baseline_op and baseline_op in results_dict: + baseline_time = results_dict[baseline_op] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + # For baseline operations, determine if this is the fastest baseline + if op_name.endswith("_native_compiled") or ( + op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + fastest_baseline = get_fastest_baseline(op_name, results_dict) + if fastest_baseline == op_name: + speedup_str = "baseline" + else: + if fastest_baseline and fastest_baseline in results_dict: + baseline_time = results_dict[fastest_baseline] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + + prepared_results.append( + { + "operation": op_name, + "time_ms": time_ms, + "time_str": time_str, + "speedup_str": speedup_str, + } + ) + + return prepared_results + + +def print_results(results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode): + """Print benchmark results in a formatted table.""" + print(f"\n{'=' * 80}") + print(f"Results: seq_len={seq_len}, hidden_dim={hidden_dim}") + print( + f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " + f"quant_mode={quant_mode}" + ) + print(f"{'=' * 80}") + print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}") + print(f"{'-' * 80}") + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + if result["time_ms"] == float("inf"): + time_display = result["time_str"] + else: + time_display = f"{result['time_ms']:.3f}" + + print( + f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}" + ) + + +def format_results_markdown( + all_results: list[dict], world_size: int, args: argparse.Namespace +) -> str: + """Format all benchmark results as markdown.""" + markdown = f"""# FlashInfer Fused Collective Operations Benchmark Results + +**World Size:** {world_size} +**Hidden Dimension:** {args.hidden_dim} +**Warmup Iterations:** {args.warmup} +**Benchmark Trials:** {args.trials} +**Quantization Mode:** {all_results[0]["quant_mode"] if all_results else "N/A"} + +--- + +""" + + for result in all_results: + seq_len = result["seq_len"] + dtype = result["dtype"] + use_residual = result["use_residual"] + results_dict = result["results"] + + residual_str = "with residual" if use_residual else "no residual" + + markdown += f""" +## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str} + +| Operation | Time (ms) | Speedup | +|-----------|-----------|---------| +""" + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + # Format operation name for better readability + formatted_op_name = result["operation"].replace("_", " ").title() + markdown += f"| {formatted_op_name} | {result['time_str']} |" + markdown += f"{result['speedup_str']} |\n" + + markdown += "\n" + + return markdown + + +def save_results_to_file( + all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int +): + """Save benchmark results to markdown file (only on rank 0).""" + if rank != 0: + return + + if not all_results: + logger.warning("No results to save") + return + + output_path = args.output_file + + try: + markdown_content = format_results_markdown(all_results, world_size, args) + + with open(output_path, "w") as f: + f.write(markdown_content) + + except Exception as e: + logger.error("Failed to save results to file: %s", e) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark fused collective operations" + ) + parser.add_argument( + "--seq-lens", + type=int, + nargs="+", + default=[128, 512, 1024, 2048], + help="Sequence lengths to test", + ) + parser.add_argument( + "--hidden-dim", type=int, default=8192, help="Hidden dimension size" + ) + parser.add_argument( + "--dtypes", + type=str, + nargs="+", + default=["bfloat16"], + choices=["float16", "bfloat16", "float32"], + help="Data types to test", + ) + parser.add_argument( + "--no-residual", + action="store_true", + help="Skip residual connection tests", + ) + + # Quantization mode options (mutually exclusive with --no-quant) + quant_group = parser.add_mutually_exclusive_group() + quant_group.add_argument( + "--no-quant", action="store_true", help="Skip all quantization tests" + ) + quant_group.add_argument( + "--quant-fp8", action="store_true", help="Only run FP8 quantization tests" + ) + quant_group.add_argument( + "--quant-fp4", action="store_true", help="Only run FP4 quantization tests" + ) + quant_group.add_argument( + "--quant-all", + action="store_true", + help="Run all quantization tests (default)", + ) + + parser.add_argument( + "--disable-oneshot", + action="store_true", + help="Disable oneshot mode for FlashInfer operations", + ) + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--trials", type=int, default=20, help="Number of benchmark trials" + ) + parser.add_argument( + "--output-file", + type=str, + help="""Output file path for markdown results + (default: benchmark_results_.md) + """, + ) + + args = parser.parse_args() + + # Check if running with torchrun (required for collective operations) + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + raise RuntimeError( + "Must run with torchrun for distributed benchmarking. " + "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py" + ) + + # Initialize distributed environment + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Validate world size (must be > 1 for collective operations) + if world_size <= 1: + raise ValueError( + "World size must be > 1 for collective operations benchmarking. " + f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1." + ) + + # Determine quantization mode + if args.no_quant: + quant_mode = "none" + elif args.quant_fp8: + quant_mode = "fp8_only" + elif args.quant_fp4: + quant_mode = "fp4_only" + else: # args.quant_all or default + quant_mode = "all" + + if rank == 0: + logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank) + logger.info("Quantization mode: %s", quant_mode) + if flashinfer_comm is not None: + oneshot_status = "enabled" if not args.disable_oneshot else "disabled" + logger.info( + "FlashInfer available - will benchmark fused operations (oneshot: %s)", + oneshot_status, + ) + else: + logger.info( + "FlashInfer not available - only benchmarking standard operations" + ) + + # Convert dtype strings to torch dtypes + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + dtypes = [dtype_map[dt] for dt in args.dtypes] + + # Test configurations + residual_options = [True] if not args.no_residual else [False] + if not args.no_residual: + residual_options.append(False) + + configs = list(itertools.product(args.seq_lens, dtypes, residual_options)) + + # Setup FlashInfer workspace if available + ipc_handles = None + allreduce_params = None + + if flashinfer_comm is not None: + # Use the largest hidden dimension for workspace setup + max_num_token = _FI_MAX_SIZES.get(world_size) // ( + args.hidden_dim * world_size * 2 + ) + + ipc_handles, workspace_tensor = setup_flashinfer_workspace( + world_size, rank, args.hidden_dim, max_num_token + ) + + if workspace_tensor is not None: + allreduce_params = FlashInferFusedAllReduceParams( + rank=rank, + world_size=world_size, + max_token_num=max_num_token, + ) + + # Collect all results for markdown export + all_results = [] + + try: + # Run benchmarks + for seq_len, dtype, use_residual in configs: + if rank == 0: + logger.info( + "\nTesting: seq_len=%s, hidden_dim=%s, dtype=%s, residual=%s", + seq_len, + args.hidden_dim, + dtype, + use_residual, + ) + + results = run_benchmarks( + seq_len, + args.hidden_dim, + dtype, + use_residual, + allreduce_params, + quant_mode=quant_mode, + disable_oneshot=args.disable_oneshot, + ) + + # Store results for markdown export + if rank == 0: + all_results.append( + { + "seq_len": seq_len, + "hidden_dim": args.hidden_dim, + "dtype": str(dtype).replace("torch.", ""), + "use_residual": use_residual, + "quant_mode": quant_mode, + "results": results, + } + ) + + print_results( + results, + seq_len, + args.hidden_dim, + dtype, + use_residual, + quant_mode, + ) + + # Save results to markdown file + if args.output_file and rank == 0: + save_results_to_file(all_results, world_size, args, rank) + + finally: + # Cleanup + if ipc_handles is not None: + cleanup_flashinfer_workspace(ipc_handles) + + dist.barrier() + + +if __name__ == "__main__": + main() diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 7688ba3d1b6c..4798dbf1df1e 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -329,4 +329,4 @@ def all_reduce_fusion_pass_on_test_model( ) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) - del all_reduce_fusion_pass + del all_reduce_fusion_pass \ No newline at end of file diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index c1ed058ded70..c99c63aedc2a 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -9,8 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group -import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, @@ -454,31 +453,21 @@ def __call__(self, graph: fx.Graph): _FI_WORKSPACE_TENSOR = None MiB = 1024 * 1024 - # Max size of the input tensor per world size - # to use flashinfer fused allreduce - _FI_MAX_SIZES = { - 2: 64 * MiB, # 64MB - 4: MiB, # 1MB - 6: MiB // 2, # 512KB - 8: MiB // 2, # 512KB + # Max size of the input tensor per world size per device capability + # to use flashinfer one shot fused allreduce + _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES = { + "9.0": { + 2: 32 * MiB, # 32MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 32 * MiB, # 32MB + 4: 4 * MiB, # 4MB + 8: 1 * MiB, # 1MB + }, } - try: - _FI_MAX_SIZES.update( - { - int(k): int(float(v) * MiB) - for k, v in envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() - } - ) - except Exception as e: - raise ValueError( - "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + str(e) - ) from e - - # opt for a more conservative default value - # when world size is not in _FI_MAX_SIZES - _DEFAULT_FI_MAX_SIZE = MiB // 2 - def call_trtllm_fused_allreduce_norm( allreduce_in: torch.Tensor, residual: torch.Tensor, @@ -500,15 +489,22 @@ def call_trtllm_fused_allreduce_norm( num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size - max_fusion_size = max_token_num * hidden_size * element_size - use_flashinfer = current_tensor_size <= min( - _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), - max_fusion_size, - ) - if use_flashinfer: - assert _FI_WORKSPACE_TENSOR is not None, ( - "Flashinfer must be enabled when using flashinfer" - ) + max_tensor_size = max_token_num * hidden_size * element_size + + if current_tensor_size <= max_tensor_size: + device_capability = current_platform.get_device_capability( + ).as_version_str() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES. \ + get(device_capability, {}). \ + get(world_size, None) + # Use one shot if no max size is specified + use_oneshot = max_one_shot_size is None or \ + current_tensor_size <= max_one_shot_size + + assert (_FI_WORKSPACE_TENSOR is not None + ), "Flashinfer must be enabled when using flashinfer" if norm_out is None: norm_out = allreduce_in residual_out = residual @@ -532,7 +528,7 @@ def call_trtllm_fused_allreduce_norm( hidden_dim=allreduce_in.shape[-1], workspace_ptrs=_FI_WORKSPACE_TENSOR, launch_with_pdl=launch_with_pdl, - use_oneshot=True, + use_oneshot=use_oneshot, trigger_completion_at_end=trigger_completion_at_end, fp32_acc=fp32_acc, pattern_code=pattern_code, @@ -545,7 +541,8 @@ def call_trtllm_fused_allreduce_norm( ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None and fuse_rms_quant: + if (scale_factor is not None and scale_out is None and + fuse_rms_quant): # Do fused rms norm static fp8 quant fused op if norm_out is None: torch.ops._C.fused_add_rms_norm_static_fp8_quant( @@ -637,10 +634,9 @@ def __init__( self.trigger_completion_at_end = True self.launch_with_pdl = True self.fp32_acc = True - self.use_oneshot = False self.max_token_num = max_token_num self.fuse_rms_quant = fuse_rms_quant - + def get_trtllm_fused_allreduce_kwargs(self): return { "world_rank": self.rank, @@ -1096,7 +1092,6 @@ def replacement( pattern, replacement, get_inputs(), pm.fwd_only, pm_pass ) - class AllReduceFusionPass(VllmPatternMatcherPass): def __init__(self, config: VllmConfig): super().__init__(config) @@ -1119,23 +1114,27 @@ def __init__(self, config: VllmConfig): "skipping allreduce fusion pass" ) return - # Check if the world size is supported - if self.tp_size not in _FI_MAX_SIZES: + max_size = config.compilation_config.\ + pass_config.flashinfer_max_size(self.tp_size) + if max_size is None: + # Flashinfer doesn't support current world size logger.warning( "Flashinfer allreduce fusion is not supported for world size %s", self.tp_size, ) return - max_num_token = min( - _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) - // (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), - config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num, - ) + element_size = 4 if use_fp32_lamport else 2 + max_token_num = (max_size // (self.hidden_dim * element_size)) + # take the min to save workspace size and we'll never use more + # than max_num_batched_tokens anyways + max_token_num = min(max_token_num, + config.scheduler_config.max_num_batched_tokens) + self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=self.tp_size, - max_token_num=max_num_token, + max_token_num=max_token_num, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -1148,11 +1147,10 @@ def __init__(self, config: VllmConfig): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=max_num_token, + max_token_num=max_token_num, # fuse rms norm static fp8 quant fused op # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion, - ) + fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) self.register_patterns() self.dump_patterns(config, self.patterns) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index a34fb0bf920c..84bc5e19c74c 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -109,11 +109,66 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_token_num: int = 16384 - """Max number of tokens to used in flashinfer allreduce fusion.""" + fi_allreduce_fusion_max_size_mb: dict[int, + float] = field(default_factory=dict) + """The thresholds of the communicated tensor sizes under which + vllm should use flashinfer fused allreduce. Specified as a + dictionary mapping each world size to the threshold in MB + { : } + Unspecified world sizes will fallback to + _FI_ALLREDUCE_MAX_INPUT_SIZES = { + "9.0": { + 2: 64 * MiB, # 64MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 64 * MiB, # 64MB + 4: 32 * MiB, # 32MB + 8: 1 * MiB, # 1MB + }, + }, where key is the device capability""" # TODO(luka) better pass enabling system. + def flashinfer_max_size(self, world_size: int) -> Optional[int]: + """ + Returns the max communication size in bytes for flashinfer + allreduce fusion for the given world size. Falls back to + conservative defaults if the world size is not specified in config. + """ + + # import here to avoid circular dependencies + from vllm.platforms import current_platform + MiB = 1024 * 1024 + + # Max size of the input tensor per world size per device capability + # to use flashinfer fused allreduce + _FI_ALLREDUCE_MAX_INPUT_SIZES = { + "9.0": { + 2: 64 * MiB, # 64MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 64 * MiB, # 64MB + 4: 32 * MiB, # 32MB + 8: 1 * MiB, # 1MB + }, + } + + device_capability = current_platform.get_device_capability( + ).as_version_str() + max_sizes = _FI_ALLREDUCE_MAX_INPUT_SIZES.get(device_capability, {}) + max_sizes.update({ + k: int(v * MiB) + for k, v in self.fi_allreduce_fusion_max_size_mb.items() + }) + if world_size not in max_sizes: + # FlashInfer doesn't support other world sizes + return None + return max_sizes[world_size] + def uuid(self): """ Produces a hash unique to the pass configuration. @@ -134,6 +189,11 @@ def __post_init__(self) -> None: "Fusion enabled but reshape elimination disabled. " "Attention + quant (fp8) fusion might not work" ) + if self.enable_fi_allreduce_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "Allreduce + rms norm + quant (fp8) fusion might not work" + ) @config From c4c0215874a0dec0981625c336e4449ee0e88e72 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 5 Sep 2025 05:58:33 -0700 Subject: [PATCH 075/137] Update bench Signed-off-by: ilmarkov --- .../kernels/benchmark_fused_collective.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index ea78875c62cf..7f012af36a94 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -187,7 +187,7 @@ def flashinfer_fused_allreduce_rmsnorm( allreduce_out=None, quant_out=None, scale_out=None, - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4_, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, scale_factor=None, use_oneshot=use_oneshot, **allreduce_params.get_trtllm_fused_allreduce_kwargs(), @@ -962,10 +962,15 @@ def get_fastest_baseline(op_name, results_dict): return prepared_results -def print_results(results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode): +def print_results( + results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode, input_size_mb +): """Print benchmark results in a formatted table.""" print(f"\n{'=' * 80}") - print(f"Results: seq_len={seq_len}, hidden_dim={hidden_dim}") + print( + f"Results: seq_len={seq_len}, hidden_dim={hidden_dim} " + f"(input size: {input_size_mb:.2f} MB)" + ) print( f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " f"quant_mode={quant_mode}" @@ -1009,11 +1014,12 @@ def format_results_markdown( dtype = result["dtype"] use_residual = result["use_residual"] results_dict = result["results"] - + input_size_mb = result["input_size_mb"] residual_str = "with residual" if use_residual else "no residual" markdown += f""" ## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str} +**Input Size:** {input_size_mb:.2f} MB | Operation | Time (ms) | Speedup | |-----------|-----------|---------| @@ -1234,6 +1240,10 @@ def main(): # Store results for markdown export if rank == 0: + # Calculate input size in MB + input_size_mb = ( + seq_len * args.hidden_dim * torch.finfo(dtype).bits + ) / (8 * 1024 * 1024) all_results.append( { "seq_len": seq_len, @@ -1241,6 +1251,7 @@ def main(): "dtype": str(dtype).replace("torch.", ""), "use_residual": use_residual, "quant_mode": quant_mode, + "input_size_mb": input_size_mb, "results": results, } ) @@ -1252,6 +1263,7 @@ def main(): dtype, use_residual, quant_mode, + input_size_mb, ) # Save results to markdown file From 309d79e8a41e1c1360adf8409ca6f38aa226a00c Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 8 Sep 2025 04:41:09 -0700 Subject: [PATCH 076/137] Update threshold configuration Signed-off-by: ilmarkov --- vllm/config/compilation.py | 59 ++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 84bc5e19c74c..ff3a092fe538 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -134,40 +134,24 @@ class PassConfig: def flashinfer_max_size(self, world_size: int) -> Optional[int]: """ Returns the max communication size in bytes for flashinfer - allreduce fusion for the given world size. Falls back to - conservative defaults if the world size is not specified in config. + allreduce fusion for the given world size. Returns None if world size + is not supported by configs as it's not supported by flashinfer. """ # import here to avoid circular dependencies from vllm.platforms import current_platform MiB = 1024 * 1024 - # Max size of the input tensor per world size per device capability - # to use flashinfer fused allreduce - _FI_ALLREDUCE_MAX_INPUT_SIZES = { - "9.0": { - 2: 64 * MiB, # 64MB - 4: 2 * MiB, # 2MB - 8: 1 * MiB, # 1MB - }, - "10.0": { - 2: 64 * MiB, # 64MB - 4: 32 * MiB, # 32MB - 8: 1 * MiB, # 1MB - }, - } - device_capability = current_platform.get_device_capability( ).as_version_str() - max_sizes = _FI_ALLREDUCE_MAX_INPUT_SIZES.get(device_capability, {}) - max_sizes.update({ + fi_allreduce_fusion_max_size_mb = \ + self.fi_allreduce_fusion_max_size_mb.get(device_capability, {}) + max_sizes = { k: int(v * MiB) - for k, v in self.fi_allreduce_fusion_max_size_mb.items() - }) - if world_size not in max_sizes: - # FlashInfer doesn't support other world sizes - return None - return max_sizes[world_size] + for k, v in fi_allreduce_fusion_max_size_mb.items() + } + # return None if world size is not supported by flashinfer + return max_sizes.get(world_size) def uuid(self): """ @@ -195,6 +179,31 @@ def __post_init__(self) -> None: "Allreduce + rms norm + quant (fp8) fusion might not work" ) + # import here to avoid circular dependencies + from vllm.platforms import current_platform + + # Default tuned max size of the input tensor + # per world size per device capability + # to use flashinfer fused allreduce + fi_allreduce_fusion_max_size_mb = { + "9.0": { + 2: 64, # 64MB + 4: 2, # 2MB + 8: 1, # 1MB + }, + "10.0": { + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB + }, + } + device_capability = current_platform.get_device_capability( + ).as_version_str() + + max_sizes = fi_allreduce_fusion_max_size_mb.get(device_capability, {}) + max_sizes.update(self.fi_allreduce_fusion_max_size_mb) + self.fi_allreduce_fusion_max_size_mb[device_capability] = max_sizes + @config @dataclass From afcfd73f5c1b5cf1bfc17d0537ae526ad102eea2 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 8 Sep 2025 05:01:47 -0700 Subject: [PATCH 077/137] Move all_reduce from custom op in fused_moe Signed-off-by: ilmarkov --- vllm/model_executor/layers/fused_moe/layer.py | 82 +++++++++---------- 1 file changed, 40 insertions(+), 42 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index de4ed58e0cf4..4bd7ab12f9c0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2105,33 +2105,59 @@ def forward_native( mode="constant", value=0.0, ) + do_naive_dispatch_combine: bool = ( + self.dp_size > 1 and not self.quant_method.using_modular_kernel + ) - if self.shared_experts is None: + def reduce_output( + states: torch.Tensor, do_combine: bool = True + ) -> torch.Tensor: + if do_naive_dispatch_combine and do_combine: + states = get_ep_group().combine(states, self.is_sequence_parallel) + + if ( + not self.is_sequence_parallel + and not self.use_dp_chunking + and self.reduce_results + and (self.tp_size > 1 or self.ep_size > 1) + ): + states = self.maybe_all_reduce_tensor_model_parallel(states) + return states + + if self.shared_experts is not None: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we # will switch to using the moe_forward custom op. - fused_output = self.forward_impl(hidden_states, router_logits) - assert not isinstance(fused_output, tuple) + shared_output, fused_output = self.forward_impl( + hidden_states, router_logits + ) else: - fused_output = torch.ops.vllm.moe_forward( + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( hidden_states, router_logits, self.layer_name ) - return fused_output[..., :og_hidden_states] + return ( + reduce_output(shared_output[..., :og_hidden_states], do_combine=False), + reduce_output(fused_output[..., :og_hidden_states]), + ) else: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we # will switch to using the moe_forward custom op. - shared_output, fused_output = self.forward_impl( - hidden_states, router_logits - ) + fused_output = self.forward_impl(hidden_states, router_logits) + assert not isinstance(fused_output, tuple) else: - shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + fused_output = torch.ops.vllm.moe_forward( hidden_states, router_logits, self.layer_name ) - return ( - shared_output[..., :og_hidden_states], - fused_output[..., :og_hidden_states], - ) + if self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(fused_output, tuple) + fused_output, zero_expert_result = fused_output + return ( + reduce_output(fused_output[..., :og_hidden_states]) + + zero_expert_result + ) + else: + return reduce_output(fused_output[..., :og_hidden_states]) def forward_cuda( self, @@ -2360,35 +2386,7 @@ def forward_impl( shared_output, final_hidden_states, ) - elif self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(final_hidden_states, tuple) - final_hidden_states, zero_expert_result = final_hidden_states - - def reduce_output( - states: torch.Tensor, do_combine: bool = True - ) -> torch.Tensor: - if do_naive_dispatch_combine and do_combine: - states = get_ep_group().combine(states, self.is_sequence_parallel) - - if ( - not self.is_sequence_parallel - and self.reduce_results - and (self.tp_size > 1 or self.ep_size > 1) - ): - states = self.maybe_all_reduce_tensor_model_parallel(states) - - return states - - if self.shared_experts is not None: - return ( - reduce_output(final_hidden_states[0], do_combine=False), - reduce_output(final_hidden_states[1]), - ) - elif self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(final_hidden_states, torch.Tensor) - return reduce_output(final_hidden_states) + zero_expert_result - else: - return reduce_output(final_hidden_states) + return final_hidden_states @classmethod def make_expert_params_mapping( From 0248dcdf9e6002b925cf399cb8d39c2e8f5d2214 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 12:51:23 +0000 Subject: [PATCH 078/137] Linter fixes Signed-off-by: ilmarkov --- vllm/compilation/collective_fusion.py | 44 +++++++++++++++------------ vllm/config/compilation.py | 20 ++++++------ 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index c99c63aedc2a..01a0ebc993ae 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -9,7 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group -from vllm.config import VllmConfig, set_current_vllm_config +from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, @@ -492,19 +492,22 @@ def call_trtllm_fused_allreduce_norm( max_tensor_size = max_token_num * hidden_size * element_size if current_tensor_size <= max_tensor_size: - device_capability = current_platform.get_device_capability( - ).as_version_str() + device_capability = ( + current_platform.get_device_capability().as_version_str() + ) # Get one shot input size limit for the current world size # for the current device capability - max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES. \ - get(device_capability, {}). \ - get(world_size, None) + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES.get( + device_capability, {} + ).get(world_size, None) # Use one shot if no max size is specified - use_oneshot = max_one_shot_size is None or \ - current_tensor_size <= max_one_shot_size + use_oneshot = ( + max_one_shot_size is None or current_tensor_size <= max_one_shot_size + ) - assert (_FI_WORKSPACE_TENSOR is not None - ), "Flashinfer must be enabled when using flashinfer" + assert _FI_WORKSPACE_TENSOR is not None, ( + "Flashinfer must be enabled when using flashinfer" + ) if norm_out is None: norm_out = allreduce_in residual_out = residual @@ -541,8 +544,7 @@ def call_trtllm_fused_allreduce_norm( ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if (scale_factor is not None and scale_out is None and - fuse_rms_quant): + if scale_factor is not None and scale_out is None and fuse_rms_quant: # Do fused rms norm static fp8 quant fused op if norm_out is None: torch.ops._C.fused_add_rms_norm_static_fp8_quant( @@ -636,7 +638,7 @@ def __init__( self.fp32_acc = True self.max_token_num = max_token_num self.fuse_rms_quant = fuse_rms_quant - + def get_trtllm_fused_allreduce_kwargs(self): return { "world_rank": self.rank, @@ -1092,6 +1094,7 @@ def replacement( pattern, replacement, get_inputs(), pm.fwd_only, pm_pass ) + class AllReduceFusionPass(VllmPatternMatcherPass): def __init__(self, config: VllmConfig): super().__init__(config) @@ -1114,8 +1117,9 @@ def __init__(self, config: VllmConfig): "skipping allreduce fusion pass" ) return - max_size = config.compilation_config.\ - pass_config.flashinfer_max_size(self.tp_size) + max_size = config.compilation_config.pass_config.flashinfer_max_size( + self.tp_size + ) if max_size is None: # Flashinfer doesn't support current world size logger.warning( @@ -1124,11 +1128,12 @@ def __init__(self, config: VllmConfig): ) return element_size = 4 if use_fp32_lamport else 2 - max_token_num = (max_size // (self.hidden_dim * element_size)) + max_token_num = max_size // (self.hidden_dim * element_size) # take the min to save workspace size and we'll never use more # than max_num_batched_tokens anyways - max_token_num = min(max_token_num, - config.scheduler_config.max_num_batched_tokens) + max_token_num = min( + max_token_num, config.scheduler_config.max_num_batched_tokens + ) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( @@ -1150,7 +1155,8 @@ def __init__(self, config: VllmConfig): max_token_num=max_token_num, # fuse rms norm static fp8 quant fused op # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) + fuse_rms_quant=config.compilation_config.pass_config.enable_fusion, + ) self.register_patterns() self.dump_patterns(config, self.patterns) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ff3a092fe538..ee0c40f4ef42 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -109,8 +109,7 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_size_mb: dict[int, - float] = field(default_factory=dict) + fi_allreduce_fusion_max_size_mb: dict[int, float] = field(default_factory=dict) """The thresholds of the communicated tensor sizes under which vllm should use flashinfer fused allreduce. Specified as a dictionary mapping each world size to the threshold in MB @@ -131,7 +130,7 @@ class PassConfig: # TODO(luka) better pass enabling system. - def flashinfer_max_size(self, world_size: int) -> Optional[int]: + def flashinfer_max_size(self, world_size: int) -> int | None: """ Returns the max communication size in bytes for flashinfer allreduce fusion for the given world size. Returns None if world size @@ -140,15 +139,15 @@ def flashinfer_max_size(self, world_size: int) -> Optional[int]: # import here to avoid circular dependencies from vllm.platforms import current_platform + MiB = 1024 * 1024 - device_capability = current_platform.get_device_capability( - ).as_version_str() - fi_allreduce_fusion_max_size_mb = \ - self.fi_allreduce_fusion_max_size_mb.get(device_capability, {}) + device_capability = current_platform.get_device_capability().as_version_str() + fi_allreduce_fusion_max_size_mb = self.fi_allreduce_fusion_max_size_mb.get( + device_capability, {} + ) max_sizes = { - k: int(v * MiB) - for k, v in fi_allreduce_fusion_max_size_mb.items() + k: int(v * MiB) for k, v in fi_allreduce_fusion_max_size_mb.items() } # return None if world size is not supported by flashinfer return max_sizes.get(world_size) @@ -197,8 +196,7 @@ def __post_init__(self) -> None: 8: 1, # 1MB }, } - device_capability = current_platform.get_device_capability( - ).as_version_str() + device_capability = current_platform.get_device_capability().as_version_str() max_sizes = fi_allreduce_fusion_max_size_mb.get(device_capability, {}) max_sizes.update(self.fi_allreduce_fusion_max_size_mb) From 18e477160a207d73d3761c23d35ac78f19372d02 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 13:26:23 +0000 Subject: [PATCH 079/137] Upd Signed-off-by: ilmarkov --- .../kernels/benchmark_fused_collective.py | 59 +++++++++---------- tests/compile/test_fusion_all_reduce.py | 2 +- vllm/config/compilation.py | 25 +++----- 3 files changed, 39 insertions(+), 47 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index 7f012af36a94..0d1ec49e3f41 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -17,7 +17,6 @@ import itertools import os import time -from typing import Optional import torch # type: ignore import torch.distributed as dist # type: ignore @@ -156,12 +155,12 @@ def get_trtllm_fused_allreduce_kwargs(self): def flashinfer_fused_allreduce_rmsnorm( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rms_gamma: torch.Tensor, rms_eps: float, allreduce_params: "FlashInferFusedAllReduceParams", use_oneshot: bool, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """FlashInfer fused allreduce + rmsnorm operation.""" if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: @@ -196,14 +195,14 @@ def flashinfer_fused_allreduce_rmsnorm( def flashinfer_fused_allreduce_rmsnorm_fp8_quant( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rms_gamma: torch.Tensor, rms_eps: float, scale_factor: torch.Tensor, allreduce_params: FlashInferFusedAllReduceParams, use_oneshot: bool = True, - norm_out: Optional[torch.Tensor] = None, - quant_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, ): """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: @@ -238,7 +237,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( def flashinfer_fused_allreduce_rmsnorm_fp4_quant( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rms_gamma: torch.Tensor, rms_eps: float, input_global_scale: torch.Tensor, @@ -246,7 +245,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( quant_out: torch.Tensor, use_oneshot: bool, output_scale: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: @@ -281,10 +280,10 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( def standard_allreduce_rmsnorm( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rms_gamma: torch.Tensor, rms_eps: float, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """Standard allreduce + rmsnorm operations.""" # All-reduce first @@ -302,12 +301,12 @@ def standard_allreduce_rmsnorm( def standard_allreduce_rmsnorm_fp8_quant( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rms_gamma: torch.Tensor, rms_eps: float, scale_factor: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, - quant_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, ): """Standard allreduce + rmsnorm + FP8 quantization.""" if quant_out is None: @@ -331,13 +330,13 @@ def standard_allreduce_rmsnorm_fp8_quant( def standard_allreduce_rmsnorm_fp4_quant( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rms_gamma: torch.Tensor, rms_eps: float, input_global_scale: torch.Tensor, quant_out: torch.Tensor, output_scale: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """Standard allreduce + rmsnorm + FP4 quantization.""" @@ -366,9 +365,9 @@ def standard_allreduce_rmsnorm_fp4_quant( def standard_allreduce_rmsnorm_native( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rmsnorm_layer: RMSNorm, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """Standard allreduce + rmsnorm operations using native RMSNorm forward.""" # All-reduce first @@ -384,12 +383,12 @@ def standard_allreduce_rmsnorm_native( def standard_allreduce_rmsnorm_fp8_quant_native( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rmsnorm_layer: RMSNorm, quant_fp8_layer: QuantFP8, scale_factor: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, - quant_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, ): """Standard allreduce + rmsnorm + FP8 quantization using native implementations.""" # All-reduce first @@ -413,12 +412,12 @@ def standard_allreduce_rmsnorm_fp8_quant_native( def standard_allreduce_rmsnorm_fp4_quant_native( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rmsnorm_layer: RMSNorm, input_global_scale: torch.Tensor, quant_out: torch.Tensor, output_scale: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm.""" # All-reduce first @@ -446,9 +445,9 @@ def standard_allreduce_rmsnorm_fp4_quant_native( @torch.compile def standard_allreduce_rmsnorm_native_compiled( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rmsnorm_layer: RMSNorm, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """Compiled version of standard allreduce + rmsnorm.""" return standard_allreduce_rmsnorm_native( @@ -459,12 +458,12 @@ def standard_allreduce_rmsnorm_native_compiled( @torch.compile def standard_allreduce_rmsnorm_fp8_quant_native_compiled( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rmsnorm_layer: RMSNorm, quant_fp8_layer: QuantFP8, scale_factor: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, - quant_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, ): """Compiled version of standard allreduce + rmsnorm + FP8 quantization.""" return standard_allreduce_rmsnorm_fp8_quant_native( @@ -481,12 +480,12 @@ def standard_allreduce_rmsnorm_fp8_quant_native_compiled( @torch.compile def standard_allreduce_rmsnorm_fp4_quant_native_compiled( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rmsnorm_layer: RMSNorm, input_global_scale: torch.Tensor, quant_out: torch.Tensor, output_scale: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """Compiled version of standard allreduce + rmsnorm + FP4 quantization.""" return standard_allreduce_rmsnorm_fp4_quant_native( @@ -578,7 +577,7 @@ def run_benchmarks( hidden_dim: int, dtype: torch.dtype, use_residual: bool, - allreduce_params: Optional[FlashInferFusedAllReduceParams], + allreduce_params: FlashInferFusedAllReduceParams | None, quant_mode: str = "all", disable_oneshot: bool = False, ): diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 4798dbf1df1e..7688ba3d1b6c 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -329,4 +329,4 @@ def all_reduce_fusion_pass_on_test_model( ) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) - del all_reduce_fusion_pass \ No newline at end of file + del all_reduce_fusion_pass diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ee0c40f4ef42..2ca1959d10b0 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -117,14 +117,14 @@ class PassConfig: Unspecified world sizes will fallback to _FI_ALLREDUCE_MAX_INPUT_SIZES = { "9.0": { - 2: 64 * MiB, # 64MB - 4: 2 * MiB, # 2MB - 8: 1 * MiB, # 1MB + 2: 64, # 64MB + 4: 2, # 2MB + 8: 1, # 1MB }, "10.0": { - 2: 64 * MiB, # 64MB - 4: 32 * MiB, # 32MB - 8: 1 * MiB, # 1MB + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB }, }, where key is the device capability""" @@ -137,18 +137,11 @@ def flashinfer_max_size(self, world_size: int) -> int | None: is not supported by configs as it's not supported by flashinfer. """ - # import here to avoid circular dependencies - from vllm.platforms import current_platform - MiB = 1024 * 1024 - - device_capability = current_platform.get_device_capability().as_version_str() - fi_allreduce_fusion_max_size_mb = self.fi_allreduce_fusion_max_size_mb.get( - device_capability, {} - ) max_sizes = { - k: int(v * MiB) for k, v in fi_allreduce_fusion_max_size_mb.items() + k: int(v * MiB) for k, v in self.fi_allreduce_fusion_max_size_mb.items() } + # return None if world size is not supported by flashinfer return max_sizes.get(world_size) @@ -200,7 +193,7 @@ def __post_init__(self) -> None: max_sizes = fi_allreduce_fusion_max_size_mb.get(device_capability, {}) max_sizes.update(self.fi_allreduce_fusion_max_size_mb) - self.fi_allreduce_fusion_max_size_mb[device_capability] = max_sizes + self.fi_allreduce_fusion_max_size_mb = max_sizes @config From 9516d2bd3b8910d439985bcc9e1eb7377ce5348c Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 21 Oct 2025 13:40:33 +0000 Subject: [PATCH 080/137] Upd after review Signed-off-by: ilmarkov --- vllm/compilation/collective_fusion.py | 76 ++++++++++--------- vllm/config/compilation.py | 36 ++++----- vllm/model_executor/layers/fused_moe/layer.py | 39 +++++----- 3 files changed, 72 insertions(+), 79 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 4afa99a38760..056d3f482751 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -449,24 +449,40 @@ def __call__(self, graph: fx.Graph): logger.debug("Replaced %s patterns", self.matched_count) +# Max size of the input tensor per world size per device capability +# to use flashinfer fused allreduce +FI_ALLREDUCE_FUSION_MAX_SIZE_MB = { + "9.0": { + 2: 64, # 64MB + 4: 2, # 2MB + 8: 1, # 1MB + }, + "10.0": { + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB + }, +} + +# Max size of the input tensor per world size per device capability +# to use flashinfer one shot fused allreduce +_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB = { + "9.0": { + 2: 32, # 32MB + 4: 2, # 2MB + 8: 1, # 1MB + }, + "10.0": { + 2: 32, # 32MB + 4: 4, # 4MB + 8: 1, # 1MB + }, +} + + if flashinfer_comm is not None: _FI_WORKSPACE_TENSOR = None - MiB = 1024 * 1024 - # Max size of the input tensor per world size per device capability - # to use flashinfer one shot fused allreduce - _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES = { - "9.0": { - 2: 32 * MiB, # 32MB - 4: 2 * MiB, # 2MB - 8: 1 * MiB, # 1MB - }, - "10.0": { - 2: 32 * MiB, # 32MB - 4: 4 * MiB, # 4MB - 8: 1 * MiB, # 1MB - }, - } def call_trtllm_fused_allreduce_norm( allreduce_in: torch.Tensor, @@ -480,7 +496,6 @@ def call_trtllm_fused_allreduce_norm( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, @@ -497,12 +512,13 @@ def call_trtllm_fused_allreduce_norm( ) # Get one shot input size limit for the current world size # for the current device capability - max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES.get( + max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( device_capability, {} ).get(world_size, None) - # Use one shot if no max size is specified + # Use one shot if no max size for one shot is specified use_oneshot = ( - max_one_shot_size is None or current_tensor_size <= max_one_shot_size + max_one_shot_size_mb is None + or current_tensor_size <= max_one_shot_size_mb * MiB ) assert _FI_WORKSPACE_TENSOR is not None, ( @@ -544,7 +560,7 @@ def call_trtllm_fused_allreduce_norm( ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None and fuse_rms_quant: + if scale_factor is not None and scale_out is None: # Do fused rms norm static fp8 quant fused op if norm_out is None: torch.ops._C.fused_add_rms_norm_static_fp8_quant( @@ -567,15 +583,10 @@ def call_trtllm_fused_allreduce_norm( norm_out = allreduce_out else: torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) - if scale_factor is not None: - if scale_out is not None: - torch.ops._C.scaled_fp4_quant( - quant_out, norm_out, scale_out, scale_factor - ) - else: - torch.ops._C.static_scaled_fp8_quant( - quant_out, norm_out, scale_factor - ) + if scale_factor is not None and scale_out is not None: + torch.ops._C.scaled_fp4_quant( + quant_out, norm_out, scale_out, scale_factor + ) if scale_factor is None or norm_out is not None: # we need to return allreduce output # in cases of non quant fused AR + RMS norm @@ -594,7 +605,6 @@ def call_trtllm_fused_allreduce_norm_fake( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, @@ -628,7 +638,6 @@ def __init__( world_size: int, use_fp32_lamport: bool = False, max_token_num: int = 1024, - fuse_rms_quant: bool = False, ): self.rank = rank self.world_size = world_size @@ -637,7 +646,6 @@ def __init__( self.launch_with_pdl = True self.fp32_acc = True self.max_token_num = max_token_num - self.fuse_rms_quant = fuse_rms_quant def get_trtllm_fused_allreduce_kwargs(self): return { @@ -647,7 +655,6 @@ def get_trtllm_fused_allreduce_kwargs(self): "trigger_completion_at_end": self.trigger_completion_at_end, "fp32_acc": self.fp32_acc, "max_token_num": self.max_token_num, - "fuse_rms_quant": self.fuse_rms_quant, } @@ -1153,9 +1160,6 @@ def __init__(self, config: VllmConfig): world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, max_token_num=max_token_num, - # fuse rms norm static fp8 quant fused op - # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion, ) self.register_patterns() diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ee3b5cd94870..1ed5fcc8b9a8 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -116,7 +116,7 @@ class PassConfig: dictionary mapping each world size to the threshold in MB { : } Unspecified world sizes will fallback to - _FI_ALLREDUCE_MAX_INPUT_SIZES = { + FI_ALLREDUCE_FUSION_MAX_SIZE_MB = { "9.0": { 2: 64, # 64MB 4: 2, # 2MB @@ -146,6 +146,15 @@ def flashinfer_max_size(self, world_size: int) -> int | None: # return None if world size is not supported by flashinfer return max_sizes.get(world_size) + @staticmethod + def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]: + from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB + from vllm.platforms import current_platform + + return FI_ALLREDUCE_FUSION_MAX_SIZE_MB.get( + current_platform.get_device_capability().as_version_str(), {} + ) + def uuid(self): """ Produces a hash unique to the pass configuration. @@ -172,29 +181,10 @@ def __post_init__(self) -> None: "Allreduce + rms norm + quant (fp8) fusion might not work" ) - # import here to avoid circular dependencies - from vllm.platforms import current_platform - - # Default tuned max size of the input tensor - # per world size per device capability - # to use flashinfer fused allreduce - fi_allreduce_fusion_max_size_mb = { - "9.0": { - 2: 64, # 64MB - 4: 2, # 2MB - 8: 1, # 1MB - }, - "10.0": { - 2: 64, # 64MB - 4: 32, # 32MB - 8: 1, # 1MB - }, + self.fi_allreduce_fusion_max_size_mb = { + **PassConfig.default_fi_allreduce_fusion_max_size_mb(), + **self.fi_allreduce_fusion_max_size_mb, } - device_capability = current_platform.get_device_capability().as_version_str() - - max_sizes = fi_allreduce_fusion_max_size_mb.get(device_capability, {}) - max_sizes.update(self.fi_allreduce_fusion_max_size_mb) - self.fi_allreduce_fusion_max_size_mb = max_sizes @config diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4a3ebbd74540..68982e37d825 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2148,22 +2148,7 @@ def reduce_output( states = self.maybe_all_reduce_tensor_model_parallel(states) return states - if self.shared_experts is not None: - if current_platform.is_tpu(): - # TODO: Once the OOM issue for the TPU backend is resolved, we - # will switch to using the moe_forward custom op. - shared_output, fused_output = self.forward_impl( - hidden_states, router_logits - ) - else: - shared_output, fused_output = torch.ops.vllm.moe_forward_shared( - hidden_states, router_logits, self.layer_name - ) - return ( - reduce_output(shared_output[..., :og_hidden_states], do_combine=False), - reduce_output(fused_output[..., :og_hidden_states]), - ) - else: + if self.shared_experts is None: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we # will switch to using the moe_forward custom op. @@ -2176,12 +2161,26 @@ def reduce_output( if self.zero_expert_num is not None and self.zero_expert_num > 0: assert isinstance(fused_output, tuple) fused_output, zero_expert_result = fused_output - return ( - reduce_output(fused_output[..., :og_hidden_states]) - + zero_expert_result + return (reduce_output(fused_output) + zero_expert_result)[ + ..., :og_hidden_states + ] + else: + return reduce_output(fused_output)[..., :og_hidden_states] + else: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + shared_output, fused_output = self.forward_impl( + hidden_states, router_logits ) else: - return reduce_output(fused_output[..., :og_hidden_states]) + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + hidden_states, router_logits, self.layer_name + ) + return ( + reduce_output(shared_output, do_combine=False)[..., :og_hidden_states], + reduce_output(fused_output)[..., :og_hidden_states], + ) def forward_cuda( self, From b789044ffe53f1e789cc8bae0f87109389f805e7 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 27 Oct 2025 11:09:43 +0000 Subject: [PATCH 081/137] Update fused_moe Signed-off-by: ilmarkov --- vllm/model_executor/layers/fused_moe/layer.py | 112 +++++++++--------- 1 file changed, 59 insertions(+), 53 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 68982e37d825..2aa42bd61a90 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2133,54 +2133,69 @@ def forward_native( self.dp_size > 1 and not self.quant_method.using_modular_kernel ) - def reduce_output( - states: torch.Tensor, do_combine: bool = True - ) -> torch.Tensor: - if do_naive_dispatch_combine and do_combine: - states = get_ep_group().combine(states, self.is_sequence_parallel) - - if ( - not self.is_sequence_parallel - and not self.use_dp_chunking - and self.reduce_results - and (self.tp_size > 1 or self.ep_size > 1) - ): - states = self.maybe_all_reduce_tensor_model_parallel(states) - return states + ctx = get_forward_context() + sp_ctx = ( + ctx.dp_metadata.sp_local_sizes(self.sp_size) + if ctx.dp_metadata + else nullcontext() + ) - if self.shared_experts is None: - if current_platform.is_tpu(): - # TODO: Once the OOM issue for the TPU backend is resolved, we - # will switch to using the moe_forward custom op. - fused_output = self.forward_impl(hidden_states, router_logits) - assert not isinstance(fused_output, tuple) - else: - fused_output = torch.ops.vllm.moe_forward( - hidden_states, router_logits, self.layer_name - ) - if self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(fused_output, tuple) - fused_output, zero_expert_result = fused_output - return (reduce_output(fused_output) + zero_expert_result)[ - ..., :og_hidden_states - ] - else: - return reduce_output(fused_output)[..., :og_hidden_states] - else: - if current_platform.is_tpu(): - # TODO: Once the OOM issue for the TPU backend is resolved, we - # will switch to using the moe_forward custom op. - shared_output, fused_output = self.forward_impl( - hidden_states, router_logits + with sp_ctx: + if do_naive_dispatch_combine: + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits, self.is_sequence_parallel ) + + def reduce_output( + states: torch.Tensor, do_combine: bool = True + ) -> torch.Tensor: + if do_naive_dispatch_combine and do_combine: + states = get_ep_group().combine(states, self.is_sequence_parallel) + + if ( + not self.is_sequence_parallel + and not self.use_dp_chunking + and self.reduce_results + and (self.tp_size > 1 or self.ep_size > 1) + ): + states = self.maybe_all_reduce_tensor_model_parallel(states) + return states + + if self.shared_experts is None: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + fused_output = self.forward_impl(hidden_states, router_logits) + assert not isinstance(fused_output, tuple) + else: + fused_output = torch.ops.vllm.moe_forward( + hidden_states, router_logits, self.layer_name + ) + if self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(fused_output, tuple) + fused_output, zero_expert_result = fused_output + return (reduce_output(fused_output) + zero_expert_result)[ + ..., :og_hidden_states + ] + else: + return reduce_output(fused_output)[..., :og_hidden_states] else: - shared_output, fused_output = torch.ops.vllm.moe_forward_shared( - hidden_states, router_logits, self.layer_name + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + shared_output, fused_output = self.forward_impl( + hidden_states, router_logits + ) + else: + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + hidden_states, router_logits, self.layer_name + ) + return ( + reduce_output(shared_output, do_combine=False)[ + ..., :og_hidden_states + ], + reduce_output(fused_output)[..., :og_hidden_states], ) - return ( - reduce_output(shared_output, do_combine=False)[..., :og_hidden_states], - reduce_output(fused_output)[..., :og_hidden_states], - ) def forward_cuda( self, @@ -2349,10 +2364,6 @@ def forward_impl( if self.use_dp_chunking: return self.forward_impl_chunked(hidden_states, router_logits) - do_naive_dispatch_combine: bool = ( - self.dp_size > 1 and not self.quant_method.using_modular_kernel - ) - # If there are shared experts but we are not using a modular kernel, the # shared experts must be called here if ( @@ -2371,11 +2382,6 @@ def forward_impl( ) with sp_ctx: - if do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits, self.is_sequence_parallel - ) - # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, From 60776166815143a4a55ae4c5175caf42485035e3 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Sun, 2 Nov 2025 19:49:31 +0000 Subject: [PATCH 082/137] Address comments Signed-off-by: ilmarkov --- .../kernels/benchmark_fused_collective.py | 755 +++++++----------- vllm/compilation/collective_fusion.py | 3 +- 2 files changed, 304 insertions(+), 454 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index 0d1ec49e3f41..d5619287530d 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -17,10 +17,13 @@ import itertools import os import time +from collections.abc import Callable +import pandas as pd import torch # type: ignore import torch.distributed as dist # type: ignore +from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.distributed import ( get_tp_group, tensor_model_parallel_all_reduce, @@ -278,225 +281,54 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( ) -def standard_allreduce_rmsnorm( - input_tensor: torch.Tensor, - residual: torch.Tensor | None, - rms_gamma: torch.Tensor, - rms_eps: float, - norm_out: torch.Tensor | None = None, -): - """Standard allreduce + rmsnorm operations.""" - # All-reduce first - allreduce_out = tensor_model_parallel_all_reduce(input_tensor) - # Then RMS norm - if residual is not None: - # Fused add + RMS norm - FUSED_ADD_RMS_NORM_OP(allreduce_out, residual, rms_gamma, rms_eps) - else: - # Just RMS norm - if norm_out is None: - norm_out = torch.empty_like(allreduce_out) - RMS_NORM_OP(norm_out, allreduce_out, rms_gamma, rms_eps) - - -def standard_allreduce_rmsnorm_fp8_quant( - input_tensor: torch.Tensor, - residual: torch.Tensor | None, - rms_gamma: torch.Tensor, - rms_eps: float, - scale_factor: torch.Tensor, - norm_out: torch.Tensor | None = None, - quant_out: torch.Tensor | None = None, -): - """Standard allreduce + rmsnorm + FP8 quantization.""" - if quant_out is None: - quant_out = torch.empty_like(input_tensor, dtype=FP8_DTYPE) - - # All-reduce first - allreduce_out = tensor_model_parallel_all_reduce(input_tensor) - - # Then fused RMS norm + FP8 quantization - if residual is not None: - FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP( - quant_out, allreduce_out, residual, rms_gamma, scale_factor, rms_eps +class VllmFusedAllreduce: + def __init__(self, hidden_dim, dtype): + self.rms_eps = 1e-6 + self.rms_norm = RMSNorm(hidden_dim, eps=self.rms_eps, dtype=dtype) + self.fp8_quant = QuantFP8( + static=True, + group_shape=GroupShape.PER_TENSOR, ) - return quant_out, residual - else: - RMS_NORM_STATIC_FP8_QUANT_OP( - quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps - ) - return quant_out + def allreduce_rmsnorm( + self, input_tensor: torch.Tensor, residual: torch.Tensor | None + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + return self.rms_norm(allreduce_out, residual) -def standard_allreduce_rmsnorm_fp4_quant( - input_tensor: torch.Tensor, - residual: torch.Tensor | None, - rms_gamma: torch.Tensor, - rms_eps: float, - input_global_scale: torch.Tensor, - quant_out: torch.Tensor, - output_scale: torch.Tensor, - norm_out: torch.Tensor | None = None, -): - """Standard allreduce + rmsnorm + FP4 quantization.""" - - # All-reduce first - allreduce_out = tensor_model_parallel_all_reduce(input_tensor) - - # Then RMS norm - if residual is not None: - FUSED_ADD_RMS_NORM_OP(allreduce_out, residual, rms_gamma, rms_eps) - quant_input = allreduce_out - residual_out = residual - else: - if norm_out is None: - norm_out = torch.empty_like(allreduce_out) - RMS_NORM_OP(norm_out, allreduce_out, rms_gamma, rms_eps) - quant_input = norm_out - residual_out = allreduce_out - - # Finally FP4 quantization - SCALED_FP4_QUANT_OP(quant_out, quant_input, output_scale, input_global_scale) - if residual is not None: - return quant_out, residual_out, output_scale - else: - return quant_out, norm_out - - -def standard_allreduce_rmsnorm_native( - input_tensor: torch.Tensor, - residual: torch.Tensor | None, - rmsnorm_layer: RMSNorm, - norm_out: torch.Tensor | None = None, -): - """Standard allreduce + rmsnorm operations using native RMSNorm forward.""" - # All-reduce first - allreduce_out = tensor_model_parallel_all_reduce(input_tensor) - # Apply native RMSNorm - if residual is not None: - result = rmsnorm_layer.forward_native(allreduce_out, residual) - return result # Returns (norm_out, residual_out) - else: - result = rmsnorm_layer.forward_native(allreduce_out) - return result # Returns norm_out - - -def standard_allreduce_rmsnorm_fp8_quant_native( - input_tensor: torch.Tensor, - residual: torch.Tensor | None, - rmsnorm_layer: RMSNorm, - quant_fp8_layer: QuantFP8, - scale_factor: torch.Tensor, - norm_out: torch.Tensor | None = None, - quant_out: torch.Tensor | None = None, -): - """Standard allreduce + rmsnorm + FP8 quantization using native implementations.""" - # All-reduce first - allreduce_out = tensor_model_parallel_all_reduce(input_tensor) - - # Apply native RMSNorm - if residual is not None: - norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) - else: - norm_out = rmsnorm_layer.forward_native(allreduce_out) - residual_out = allreduce_out - - # Apply native FP8 quantization - quant_out, _ = quant_fp8_layer.forward_native(norm_out, scale=scale_factor) - - if residual is not None: - return quant_out, residual_out - else: - return quant_out - - -def standard_allreduce_rmsnorm_fp4_quant_native( - input_tensor: torch.Tensor, - residual: torch.Tensor | None, - rmsnorm_layer: RMSNorm, - input_global_scale: torch.Tensor, - quant_out: torch.Tensor, - output_scale: torch.Tensor, - norm_out: torch.Tensor | None = None, -): - """Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm.""" - # All-reduce first - allreduce_out = tensor_model_parallel_all_reduce(input_tensor) - - # Apply native RMSNorm - if residual is not None: - norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) - quant_input = norm_out - else: - norm_out = rmsnorm_layer.forward_native(allreduce_out) - quant_input = norm_out - residual_out = allreduce_out - - # Apply FP4 quantization (still using fused CUDA op as there's no native FP4) - SCALED_FP4_QUANT_OP(quant_out, quant_input, output_scale, input_global_scale) - - if residual is not None: - return quant_out, residual_out, output_scale - else: - return quant_out, norm_out - - -# Compiled versions of native functions -@torch.compile -def standard_allreduce_rmsnorm_native_compiled( - input_tensor: torch.Tensor, - residual: torch.Tensor | None, - rmsnorm_layer: RMSNorm, - norm_out: torch.Tensor | None = None, -): - """Compiled version of standard allreduce + rmsnorm.""" - return standard_allreduce_rmsnorm_native( - input_tensor, residual, rmsnorm_layer, norm_out - ) - - -@torch.compile -def standard_allreduce_rmsnorm_fp8_quant_native_compiled( - input_tensor: torch.Tensor, - residual: torch.Tensor | None, - rmsnorm_layer: RMSNorm, - quant_fp8_layer: QuantFP8, - scale_factor: torch.Tensor, - norm_out: torch.Tensor | None = None, - quant_out: torch.Tensor | None = None, -): - """Compiled version of standard allreduce + rmsnorm + FP8 quantization.""" - return standard_allreduce_rmsnorm_fp8_quant_native( - input_tensor, - residual, - rmsnorm_layer, - quant_fp8_layer, - scale_factor, - norm_out, - quant_out, - ) - + def allreduce_rmsnorm_fp8_quant( + self, + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + scale_factor: torch.Tensor, + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + rms_out = self.rms_norm(allreduce_out, residual) + if residual is None: + quant_out = self.fp8_quant(rms_out, scale_factor) + return quant_out + else: + rms_out, residual_out = rms_out + quant_out = self.fp8_quant(rms_out, scale_factor) + return quant_out, residual_out -@torch.compile -def standard_allreduce_rmsnorm_fp4_quant_native_compiled( - input_tensor: torch.Tensor, - residual: torch.Tensor | None, - rmsnorm_layer: RMSNorm, - input_global_scale: torch.Tensor, - quant_out: torch.Tensor, - output_scale: torch.Tensor, - norm_out: torch.Tensor | None = None, -): - """Compiled version of standard allreduce + rmsnorm + FP4 quantization.""" - return standard_allreduce_rmsnorm_fp4_quant_native( - input_tensor, - residual, - rmsnorm_layer, - input_global_scale, - quant_out, - output_scale, - norm_out, - ) + def allreduce_rmsnorm_fp4_quant( + self, + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + rms_out = self.rms_norm(allreduce_out, residual) + if residual is None: + SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale) + return quant_out, output_scale + else: + rms_out, residual_out = rms_out + SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale) + return quant_out, residual_out, output_scale def create_test_tensors( @@ -533,6 +365,23 @@ def create_test_tensors( ) +# From bench_per_token_quant_fp8.py +def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int): + def inner(*args): + torch._dynamo.mark_dynamic(args[arg_index], dim_index) + return fn(*args) + + return inner + + +def bench_compile(fn: Callable): + # recompile for different shapes + fwd = torch.compile(fn, fullgraph=True, dynamic=False) + + # First dim is explicitly dynamic to simulate vLLM usage + return with_dyn_arg(fwd, 0, 0) + + def benchmark_operation( operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs ): @@ -578,8 +427,7 @@ def run_benchmarks( dtype: torch.dtype, use_residual: bool, allreduce_params: FlashInferFusedAllReduceParams | None, - quant_mode: str = "all", - disable_oneshot: bool = False, + quant_modes: set[str], ): """Run all benchmarks for given configuration. @@ -600,46 +448,55 @@ def run_benchmarks( rms_eps = 1e-6 results = {} + vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) # Create RMSNorm and QuantFP8 layers once for native benchmarks - rmsnorm_layer = RMSNorm(hidden_dim, eps=rms_eps, dtype=dtype) - rmsnorm_layer.weight.data = rms_gamma - quant_fp8_layer = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) - if quant_mode in ["all", "none"]: + if "none" in quant_modes: # Standard AllReduce + RMSNorm - try: - time_ms = benchmark_operation( - standard_allreduce_rmsnorm, - input_tensor, - norm_out=norm_out, - residual=residual, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - ) - results["standard_allreduce_rmsnorm"] = time_ms - except Exception as e: - logger.error("Standard AllReduce+RMSNorm failed: %s", e) - results["standard_allreduce_rmsnorm"] = float("inf") + for custom_op in ["-rms_norm", "+rms_norm"]: + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op])) + ): + try: + suffix = ( + "_custom_rms_norm" if "+" in custom_op else "_native_rms_norm" + ) + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm, + input_tensor, + residual=residual, + ) + results[f"standard_allreduce_{suffix}"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm failed: %s", e) + results[f"standard_allreduce_{suffix}"] = float("inf") # Standard AllReduce + RMSNorm Native Compiled - try: - time_ms = benchmark_operation( - standard_allreduce_rmsnorm_native_compiled, - input_tensor, - residual=residual, - rmsnorm_layer=rmsnorm_layer, - norm_out=norm_out, - ) - results["standard_allreduce_rmsnorm_native_compiled"] = time_ms - except Exception as e: - logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) - results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"])) + ): + try: + standard_allreduce_rmsnorm_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_native_compiled, + input_tensor, + residual=residual, + ) + results["standard_allreduce_rmsnorm_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") - # FlashInfer Fused AllReduce + RMSNorm Oneshot + # FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot if flashinfer_comm is not None and allreduce_params is not None: - try: - if not disable_oneshot: + for use_oneshot in [True, False]: + suffix = "_oneshot" if use_oneshot else "_twoshot" + try: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm, input_tensor, @@ -648,73 +505,82 @@ def run_benchmarks( rms_gamma=rms_gamma, rms_eps=rms_eps, allreduce_params=allreduce_params, - use_oneshot=True, + use_oneshot=use_oneshot, + ) + results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = time_ms + except Exception as e: + logger.error("FlashInfer Fused AllReduce+RMSNorm failed: %s", e) + results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = float( + "inf" ) - results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms - except Exception as e: - logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e) - results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf") - # FlashInfer Fused AllReduce + RMSNorm Two-shot + if "fp8" in quant_modes: + # Standard AllReduce + RMSNorm + FP8 Quant + for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]: + suffix = ( + "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm" + ) + for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]: + suffix += ( + "_custom_quant_fp8" + if "+" in quant_fp8_custom_op + else "_native_quant_fp8" + ) + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=[rms_norm_custom_op, quant_fp8_custom_op] + ) + ) + ): + try: + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant, + input_tensor, + residual=residual, + scale_factor=scale_fp8, + ) + results[f"standard_allreduce{suffix}"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) + results[f"standard_allreduce{suffix}"] = float("inf") + + # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=["-rms_norm", "-quant_fp8"] + ) + ) + ): try: + standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant, + fullgraph=True, + dynamic=False, + ) time_ms = benchmark_operation( - flashinfer_fused_allreduce_rmsnorm, + standard_allreduce_rmsnorm_fp8_quant_native_compiled, input_tensor, residual=residual, - norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - allreduce_params=allreduce_params, - use_oneshot=False, + scale_factor=scale_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = ( + time_ms ) - results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = time_ms except Exception as e: logger.error( - "FlashInfer Fused AllReduce+RMSNorm Two-shot failed: %s", e + "Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float( + "inf" ) - results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = float("inf") - - if quant_mode in ["all", "fp8_only"]: - # Standard AllReduce + RMSNorm + FP8 Quant - try: - time_ms = benchmark_operation( - standard_allreduce_rmsnorm_fp8_quant, - input_tensor, - norm_out=norm_out, - residual=residual, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - scale_factor=scale_fp8, - quant_out=quant_out_fp8, - ) - results["standard_allreduce_rmsnorm_fp8_quant"] = time_ms - except Exception as e: - logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) - results["standard_allreduce_rmsnorm_fp8_quant"] = float("inf") - - # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled - try: - time_ms = benchmark_operation( - standard_allreduce_rmsnorm_fp8_quant_native_compiled, - input_tensor, - residual=residual, - rmsnorm_layer=rmsnorm_layer, - quant_fp8_layer=quant_fp8_layer, - scale_factor=scale_fp8, - norm_out=norm_out, - quant_out=quant_out_fp8, - ) - results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = time_ms - except Exception as e: - logger.error("Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e) - results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float( - "inf" - ) # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot if flashinfer_comm is not None and allreduce_params is not None: - try: - if not disable_oneshot: + for use_oneshot in [True, False]: + suffix = "_oneshot" if use_oneshot else "_twoshot" + try: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm_fp8_quant, input_tensor, @@ -725,87 +591,81 @@ def run_benchmarks( scale_factor=scale_fp8, quant_out=quant_out_fp8, allreduce_params=allreduce_params, - use_oneshot=True, + use_oneshot=use_oneshot, ) - results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = ( + results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = ( time_ms ) - except Exception as e: - logger.error( - "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", - e, - ) - results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = float( - "inf" + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", + e, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = ( + float("inf") + ) + + if "fp4" in quant_modes and current_platform.has_device_capability(100): + # Standard AllReduce + RMSNorm + FP4 Quant + for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]: + suffix = ( + "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm" + ) + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=[rms_norm_custom_op] + ) ) - # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Two-shot + ): + try: + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + ) + results[f"standard_allreduce_{suffix}_fp4_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e) + results[f"standard_allreduce_{suffix}_fp4_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"])) + ): try: + standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant, + fullgraph=True, + dynamic=False, + ) time_ms = benchmark_operation( - flashinfer_fused_allreduce_rmsnorm_fp8_quant, + standard_allreduce_rmsnorm_fp4_quant_native_compiled, input_tensor, - norm_out=norm_out, residual=residual, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - scale_factor=scale_fp8, - quant_out=quant_out_fp8, - allreduce_params=allreduce_params, - use_oneshot=False, + quant_out=fp4_quant_out, + input_global_scale=scale_fp4, + output_scale=fp4_output_scale, ) - results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = ( + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = ( time_ms ) except Exception as e: logger.error( - "FlashInfer Fused AllReduce+RMSNorm+FP8 Two-shot failed: %s", - e, + "Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e ) - results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = float( + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float( "inf" ) - if quant_mode in ["all", "fp4_only"]: - # Standard AllReduce + RMSNorm + FP4 Quant - try: - time_ms = benchmark_operation( - standard_allreduce_rmsnorm_fp4_quant, - input_tensor, - norm_out=norm_out, - residual=residual, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - input_global_scale=scale_fp4, - quant_out=fp4_quant_out, - output_scale=fp4_output_scale, - ) - results["standard_allreduce_rmsnorm_fp4_quant"] = time_ms - except Exception as e: - logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e) - results["standard_allreduce_rmsnorm_fp4_quant"] = float("inf") - - # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled - try: - time_ms = benchmark_operation( - standard_allreduce_rmsnorm_fp4_quant_native_compiled, - input_tensor, - residual=residual, - rmsnorm_layer=rmsnorm_layer, - input_global_scale=scale_fp4, - quant_out=fp4_quant_out, - output_scale=fp4_output_scale, - norm_out=norm_out, - ) - results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = time_ms - except Exception as e: - logger.error("Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e) - results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float( - "inf" - ) - # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot if flashinfer_comm is not None and allreduce_params is not None: - try: - if not disable_oneshot: + for use_oneshot in [True, False]: + suffix = "_oneshot" if use_oneshot else "_twoshot" + try: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm_fp4_quant, input_tensor, @@ -817,19 +677,19 @@ def run_benchmarks( allreduce_params=allreduce_params, quant_out=fp4_quant_out, output_scale=fp4_output_scale, - use_oneshot=True, + use_oneshot=use_oneshot, ) - results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = ( + results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = ( time_ms ) - except Exception as e: - logger.error( - "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", - e, - ) - results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = float( - "inf" - ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", + e, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = ( + float("inf") + ) # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot if flashinfer_comm is not None and allreduce_params is not None: @@ -962,7 +822,7 @@ def get_fastest_baseline(op_name, results_dict): def print_results( - results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode, input_size_mb + results_dict, seq_len, hidden_dim, dtype, use_residual, quant_modes, input_size_mb ): """Print benchmark results in a formatted table.""" print(f"\n{'=' * 80}") @@ -972,7 +832,7 @@ def print_results( ) print( f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " - f"quant_mode={quant_mode}" + f"quant_modes={','.join(sorted(list(quant_modes)))}" ) print(f"{'=' * 80}") print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}") @@ -996,46 +856,51 @@ def format_results_markdown( all_results: list[dict], world_size: int, args: argparse.Namespace ) -> str: """Format all benchmark results as markdown.""" - markdown = f"""# FlashInfer Fused Collective Operations Benchmark Results - -**World Size:** {world_size} -**Hidden Dimension:** {args.hidden_dim} -**Warmup Iterations:** {args.warmup} -**Benchmark Trials:** {args.trials} -**Quantization Mode:** {all_results[0]["quant_mode"] if all_results else "N/A"} - ---- - -""" - - for result in all_results: - seq_len = result["seq_len"] - dtype = result["dtype"] - use_residual = result["use_residual"] - results_dict = result["results"] - input_size_mb = result["input_size_mb"] + lines: list[str] = [] + lines.append("# FlashInfer Fused Collective Operations Benchmark Results") + lines.append("") + lines.append(f"**World Size:** {world_size} ") + lines.append(f"**Hidden Dimension:** {args.hidden_dim} ") + lines.append(f"**Warmup Iterations:** {args.warmup} ") + lines.append(f"**Benchmark Trials:** {args.trials} ") + modes = ",".join(all_results[0]["quant_modes"]) if all_results else "N/A" + lines.append(f"**Quantization Modes:** {modes} ") + lines.append("") + lines.append("---") + lines.append("") + + for entry in all_results: + seq_len = entry["seq_len"] + dtype = entry["dtype"] + use_residual = entry["use_residual"] + results_dict = entry["results"] + input_size_mb = entry["input_size_mb"] residual_str = "with residual" if use_residual else "no residual" - markdown += f""" -## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str} -**Input Size:** {input_size_mb:.2f} MB - -| Operation | Time (ms) | Speedup | -|-----------|-----------|---------| -""" - - # Prepare results with speedup calculations - prepared_results = prepare_results_with_speedups(results_dict) - - for result in prepared_results: - # Format operation name for better readability - formatted_op_name = result["operation"].replace("_", " ").title() - markdown += f"| {formatted_op_name} | {result['time_str']} |" - markdown += f"{result['speedup_str']} |\n" + lines.append( + f"## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str}" + ) + lines.append(f"**Input Size:** {input_size_mb:.2f} MB") + lines.append("") - markdown += "\n" + prepared = prepare_results_with_speedups(results_dict) + # Build DataFrame for markdown export + rows = [ + { + "Operation": r["operation"].replace("_", " ").title(), + "Time (ms)": r["time_str"], + "Speedup": r["speedup_str"], + } + for r in prepared + ] + df = pd.DataFrame(rows) + if df.empty: + lines.append("No results.") + else: + lines.append(df.to_markdown(index=False)) + lines.append("") - return markdown + return "\n".join(lines) def save_results_to_file( @@ -1089,28 +954,16 @@ def main(): help="Skip residual connection tests", ) - # Quantization mode options (mutually exclusive with --no-quant) - quant_group = parser.add_mutually_exclusive_group() - quant_group.add_argument( - "--no-quant", action="store_true", help="Skip all quantization tests" - ) - quant_group.add_argument( - "--quant-fp8", action="store_true", help="Only run FP8 quantization tests" - ) - quant_group.add_argument( - "--quant-fp4", action="store_true", help="Only run FP4 quantization tests" - ) - quant_group.add_argument( - "--quant-all", - action="store_true", - help="Run all quantization tests (default)", - ) - parser.add_argument( - "--disable-oneshot", - action="store_true", - help="Disable oneshot mode for FlashInfer operations", + "--quant-modes", + type=str, + default="none,fp8,fp4", + help=( + "Comma-separated quantization modes to run: none, fp8, fp4. " + "Default: none,fp8,fp4" + ), ) + parser.add_argument( "--warmup", type=int, default=5, help="Number of warmup iterations" ) @@ -1152,24 +1005,25 @@ def main(): f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1." ) - # Determine quantization mode - if args.no_quant: - quant_mode = "none" - elif args.quant_fp8: - quant_mode = "fp8_only" - elif args.quant_fp4: - quant_mode = "fp4_only" - else: # args.quant_all or default - quant_mode = "all" + # Parse quantization modes + valid_quant_modes = {"none", "fp8", "fp4"} + raw_modes = [ + m.strip().lower() for m in (args.quant_modes or "").split(",") if m.strip() + ] + quant_modes = set(raw_modes) if raw_modes else {"none", "fp8", "fp4"} + invalid = sorted(list(quant_modes - valid_quant_modes)) + if invalid: + raise ValueError( + f"Invalid --quant-modes entries: {','.join(invalid)}. " + f"Valid options are: {','.join(sorted(valid_quant_modes))}." + ) if rank == 0: logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank) - logger.info("Quantization mode: %s", quant_mode) + logger.info("Quantization modes: %s", ",".join(sorted(list(quant_modes)))) if flashinfer_comm is not None: - oneshot_status = "enabled" if not args.disable_oneshot else "disabled" logger.info( - "FlashInfer available - will benchmark fused operations (oneshot: %s)", - oneshot_status, + "FlashInfer available - will benchmark fused operations", ) else: logger.info( @@ -1186,8 +1040,6 @@ def main(): # Test configurations residual_options = [True] if not args.no_residual else [False] - if not args.no_residual: - residual_options.append(False) configs = list(itertools.product(args.seq_lens, dtypes, residual_options)) @@ -1233,8 +1085,7 @@ def main(): dtype, use_residual, allreduce_params, - quant_mode=quant_mode, - disable_oneshot=args.disable_oneshot, + quant_modes=quant_modes, ) # Store results for markdown export @@ -1249,7 +1100,7 @@ def main(): "hidden_dim": args.hidden_dim, "dtype": str(dtype).replace("torch.", ""), "use_residual": use_residual, - "quant_mode": quant_mode, + "quant_modes": sorted(list(quant_modes)), "input_size_mb": input_size_mb, "results": results, } @@ -1261,7 +1112,7 @@ def main(): args.hidden_dim, dtype, use_residual, - quant_mode, + quant_modes, input_size_mb, ) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 056d3f482751..cf89182357f2 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -504,9 +504,8 @@ def call_trtllm_fused_allreduce_norm( num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size - max_tensor_size = max_token_num * hidden_size * element_size - if current_tensor_size <= max_tensor_size: + if num_tokens <= max_token_num: device_capability = ( current_platform.get_device_capability().as_version_str() ) From afc8af8b427b5f487c5716f7278a6b8b44bf63d6 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Sun, 2 Nov 2025 19:54:31 +0000 Subject: [PATCH 083/137] Remove bench_compile Signed-off-by: ilmarkov --- .../kernels/benchmark_fused_collective.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index d5619287530d..cec134ff9138 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -17,7 +17,6 @@ import itertools import os import time -from collections.abc import Callable import pandas as pd import torch # type: ignore @@ -365,23 +364,6 @@ def create_test_tensors( ) -# From bench_per_token_quant_fp8.py -def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int): - def inner(*args): - torch._dynamo.mark_dynamic(args[arg_index], dim_index) - return fn(*args) - - return inner - - -def bench_compile(fn: Callable): - # recompile for different shapes - fwd = torch.compile(fn, fullgraph=True, dynamic=False) - - # First dim is explicitly dynamic to simulate vLLM usage - return with_dyn_arg(fwd, 0, 0) - - def benchmark_operation( operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs ): From c3af2af0b8be65ecd1a8538bcfb9622e873e6b3c Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 4 Sep 2025 07:25:57 -0700 Subject: [PATCH 084/137] Split PR. Second part. Compile ranges Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 86 ++++++++++++++ vllm/compilation/backends.py | 104 +++++++--------- vllm/compilation/collective_fusion.py | 144 +++++++++-------------- vllm/compilation/compiler_interface.py | 40 ++++--- vllm/compilation/inductor_pass.py | 11 +- vllm/compilation/pass_manager.py | 4 +- vllm/compilation/piecewise_backend.py | 57 +++++---- vllm/compilation/sequence_parallelism.py | 6 +- vllm/config/compilation.py | 33 ++++++ 9 files changed, 288 insertions(+), 197 deletions(-) create mode 100644 tests/compile/test_compile_ranges.py diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py new file mode 100644 index 000000000000..6759da199f4b --- /dev/null +++ b/tests/compile/test_compile_ranges.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import nn +from torch.library import Library + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) +from vllm.forward_context import set_forward_context +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + +BATCH_SIZE = 64 +MLP_SIZE = 128 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + out.copy_(q) + out += k + out += v + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + +@support_torch_compile +class TestModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + +@torch.inference_mode +def run_model(vllm_config: VllmConfig, model: nn.Module, + batch_sizes: list[int]): + with set_forward_context({}, vllm_config=vllm_config): + model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + for batch_size in batch_sizes: + model(torch.randn(batch_size, MLP_SIZE).cuda()) + + +def test_compile_ranges(): + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + compile_ranges_split_points=[8, 32], + )) + + with set_current_vllm_config(vllm_config): + model = TestModel(vllm_config=vllm_config, prefix='').eval().cuda() + batch_sizes = [1, 16, 48] + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, model, batch_sizes) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 53fd5e74dc0a..686c415f7ac3 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -80,7 +80,8 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[int | None, int, str], Any] = dict() + self.cache: dict[tuple[tuple[int, int] | None, int, str], + Any] = (dict()) self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -89,11 +90,11 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, runtime_shape: int | None = None): + def compile_context(self, compile_range: tuple[int, int] | None = None): """Provide compilation context for the duration of compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. partition rules, pass context).""" - with pass_context(runtime_shape): + with pass_context(compile_range): if self.compilation_config.use_inductor_graph_partition: inductor_partition_ops = resolve_defined_ops( self.compilation_config.splitting_ops @@ -150,29 +151,25 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, ) -> Callable | None: - if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + if (compile_range, graph_index, self.compiler.name) not in self.cache: return None - handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] - compiled_graph = self.compiler.load( - handle, graph, example_inputs, graph_index, runtime_shape - ) - if runtime_shape is None: + handle = self.cache[(compile_range, graph_index, self.compiler.name)] + compiled_graph = self.compiler.load(handle, graph, example_inputs, + graph_index, compile_range) + if compile_range is None: logger.debug( - "Directly load the %s-th graph for dynamic shape from %s via handle %s", + "Directly load the %s-th graph for dynamic compile range from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Directly load the %s-th graph for shape %s from %s via handle %s", - graph_index, - str(runtime_shape), - self.compiler.name, - handle, - ) + "Directly load the %s-th graph for compile range %s from %s via " + "handle %s", graph_index, str(compile_range), + self.compiler.name, handle) return compiled_graph def compile( @@ -183,7 +180,7 @@ def compile( compilation_config: CompilationConfig, graph_index: int = 0, num_graphs: int = 1, - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -195,15 +192,15 @@ def compile( compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) + compiled_graph = self.load(graph, example_inputs, graph_index, + compile_range) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. # there can be multiple graphs due to piecewise compilation. now = time.time() elapsed = now - compilation_start_time - compilation_config.compilation_time += elapsed - if runtime_shape is None: + if compile_range is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " "from the cache, took %.3f s", @@ -211,11 +208,9 @@ def compile( ) else: logger.info( - "Directly load the compiled graph(s) for shape %s " - "from the cache, took %.3f s", - str(runtime_shape), - elapsed, - ) + "Directly load the compiled graph(s) for compile range %s " + "from the cache, took %.3f s", str(compile_range), + elapsed) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -224,48 +219,40 @@ def compile( # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" - - with self.compile_context(runtime_shape): - compiled_graph, handle = self.compiler.compile( - graph, - example_inputs, - additional_inductor_config, - runtime_shape, - maybe_key, - ) + maybe_key = \ + f"artifact_compile_range_{compile_range}_subgraph_{graph_index}" + compiled_graph, handle = self.compiler.compile( + graph, example_inputs, additional_inductor_config, compile_range, + maybe_key) assert compiled_graph is not None, "Failed to compile the graph" # store the artifact in the cache if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: - self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle + self.cache[(compile_range, graph_index, + self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph - if runtime_shape is None: + if compile_range is None: logger.info_once( - "Cache the graph for dynamic shape for later use", scope="local" - ) + "Cache the graph for dynamic shape for later use") else: - logger.info_once( - "Cache the graph of shape %s for later use", - str(runtime_shape), - scope="local", - ) - if runtime_shape is None: + logger.info_once("Cache the graph of compile range %s for later use", + str(compile_range)) + if compile_range is None: logger.debug( - "Store the %s-th graph for dynamic shape from %s via handle %s", + "Store the %s-th graph for dynamic compile range from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Store the %s-th graph for shape %s from %s via handle %s", + "Store the %s-th graph for compile range %s from %s via handle %s", graph_index, - str(runtime_shape), + str(compile_range), self.compiler.name, handle, ) @@ -275,19 +262,16 @@ def compile( now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if runtime_shape is None: + if compile_range is None: logger.info_once( - "Compiling a graph for dynamic shape takes %.2f s", + "Compiling a graph for dynamic compile range takes %.2f s", + elapsed, scope="local", ) else: - logger.info_once( - "Compiling a graph for shape %s takes %.2f s", - runtime_shape, - elapsed, - scope="local", - ) + logger.info_once("Compiling a graph for compile range %s takes %.2f s", + str(compile_range), elapsed, scope="local") return compiled_graph @@ -408,7 +392,6 @@ def call_module( i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] global compilation_start_time - compiled_graph_for_dynamic_shape = ( self.vllm_backend.compiler_manager.compile( submod, @@ -417,9 +400,8 @@ def call_module( self.compilation_config, graph_index=index, num_graphs=len(self.compile_submod_names), - runtime_shape=None, - ) - ) + compile_range=None, + )) # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index cf89182357f2..a4758c971611 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -504,93 +504,59 @@ def call_trtllm_fused_allreduce_norm( num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size - - if num_tokens <= max_token_num: - device_capability = ( - current_platform.get_device_capability().as_version_str() - ) - # Get one shot input size limit for the current world size - # for the current device capability - max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( - device_capability, {} - ).get(world_size, None) - # Use one shot if no max size for one shot is specified - use_oneshot = ( - max_one_shot_size_mb is None - or current_tensor_size <= max_one_shot_size_mb * MiB - ) - - assert _FI_WORKSPACE_TENSOR is not None, ( - "Flashinfer must be enabled when using flashinfer" - ) - if norm_out is None: - norm_out = allreduce_in - residual_out = residual - else: - # return residual_out as allreduce_out with zeroed residual_in - # as flashinfer does not support rms_norm - # and allreduce_out together - residual_out = allreduce_in - # For the sizes that are smaller than the max size, - # we only use flashinfer one shot allreduce - flashinfer_comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - token_num=allreduce_in.shape[0], - residual_in=residual, - residual_out=residual_out, - norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - world_rank=world_rank, - world_size=world_size, - hidden_dim=allreduce_in.shape[-1], - workspace_ptrs=_FI_WORKSPACE_TENSOR, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=None, - quant_out=quant_out, - scale_out=scale_out, - # in vllm we only support swizzled layout - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, - scale_factor=scale_factor, - ) + max_tensor_size = max_token_num * hidden_size * element_size + assert current_tensor_size <= max_tensor_size, \ + f"Current tensor size {current_tensor_size} is larger than " \ + f"max token num {max_token_num} * hidden size {hidden_size} * " \ + f"element size {element_size}" + device_capability = current_platform.get_device_capability( + ).as_version_str() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB. \ + get(device_capability, {}). \ + get(world_size, None) + # Use one shot if no max size is specified + use_oneshot = max_one_shot_size is None or \ + current_tensor_size <= max_one_shot_size * MiB + + assert ( + _FI_WORKSPACE_TENSOR + is not None), "Flashinfer must be enabled when using flashinfer" + if norm_out is None: + norm_out = allreduce_in + residual_out = residual else: - allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None: - # Do fused rms norm static fp8 quant fused op - if norm_out is None: - torch.ops._C.fused_add_rms_norm_static_fp8_quant( - quant_out, - allreduce_out, - residual, - rms_gamma, - scale_factor, - rms_eps, - ) - else: - torch.ops._C.rms_norm_static_fp8_quant( - quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps - ) - else: - if norm_out is None: - torch.ops._C.fused_add_rms_norm( - allreduce_out, residual, rms_gamma, rms_eps - ) - norm_out = allreduce_out - else: - torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) - if scale_factor is not None and scale_out is not None: - torch.ops._C.scaled_fp4_quant( - quant_out, norm_out, scale_out, scale_factor - ) - if scale_factor is None or norm_out is not None: - # we need to return allreduce output - # in cases of non quant fused AR + RMS norm - # and fused AR + RMS norm + quant without fused add - allreduce_in.copy_(allreduce_out) + # return residual_out as allreduce_out with zeroed residual_in + # as flashinfer does not support rms_norm + # and allreduce_out together + residual_out = allreduce_in + # For the sizes that are smaller than the max size, + # we only use flashinfer one shot allreduce + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + token_num=allreduce_in.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + world_rank=world_rank, + world_size=world_size, + hidden_dim=allreduce_in.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=None, + quant_out=quant_out, + scale_out=scale_out, + # in vllm we only support swizzled layout + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + ) def call_trtllm_fused_allreduce_norm_fake( allreduce_in: torch.Tensor, @@ -1212,6 +1178,12 @@ def register_patterns(self): self.disabled = False @VllmInductorPass.time_and_log + def is_applicable_for_range( + self, compile_range: tuple[int, int] | None) -> bool: + if compile_range is None: + return False + return compile_range[1] - 1 <= self.max_token_num + def __call__(self, graph: fx.Graph): if self.disabled: logger.debug("AllReduceFusionPass disabled") diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 0a3f0769db94..3861bfed11d5 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -63,16 +63,17 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, - with a runtime shape. If the `runtime_shape` is None, it means + with a range. If the `compile_range` is None, it means the `example_inputs` have a dynamic shape. Otherwise, the - `runtime_shape` specifies the shape of the inputs. Right now we only - support one variable shape for all inputs, which is the batchsize - (number of tokens) during inference. + `compile_range` specifies the range of the inputs, + it could be concrete size, e.g. (4, 4). + Right now we only support one variable range of shapes for all inputs, + which is the batchsize (number of tokens) during inference. Dynamo will make sure `graph(*example_inputs)` is valid. @@ -98,7 +99,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, ) -> Callable: """ Load the compiled function from the handle. @@ -192,18 +193,21 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 current_config = {} if compiler_config is not None: current_config.update(compiler_config) - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() - if isinstance(runtime_shape, int): - dynamic_shapes = "from_example_inputs" + if isinstance(compile_range, tuple): + if compile_range[0] == compile_range[1]: + dynamic_shapes = "from_example_inputs" + else: + dynamic_shapes = "from_graph" else: dynamic_shapes = "from_tracing_context" @@ -230,7 +234,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -294,7 +298,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -308,7 +312,7 @@ def compile( current_config["fx_graph_cache"] = True current_config["fx_graph_remote_cache"] = False - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() # inductor can inplace modify the graph, so we need to copy it @@ -493,7 +497,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -589,9 +593,9 @@ def metrics_context(self) -> contextlib.AbstractContextManager: return contextlib.nullcontext() -def set_inductor_config(config, runtime_shape): - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters +def set_inductor_config(config, compile_range): + if isinstance(compile_range, tuple): + # for a specific range of batchsizes, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE config["coordinate_descent_tuning"] = ( @@ -611,7 +615,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 9af635a929b4..1b4430c82b2d 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -28,8 +28,8 @@ class PassContext: - def __init__(self, runtime_shape: int | None): - self.runtime_shape = runtime_shape + def __init__(self, compile_range: tuple[int, int] | None): + self.compile_range = compile_range def get_pass_context() -> PassContext: @@ -39,13 +39,13 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(runtime_shape: int | None): +def pass_context(compile_range: tuple[int, int] | None): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. """ global _pass_context prev_context = _pass_context - _pass_context = PassContext(runtime_shape) + _pass_context = PassContext(compile_range) try: yield finally: @@ -96,7 +96,8 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable(self, shape: int | None): + def is_applicable_for_range(self, compile_range: tuple[int, + int] | None): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 3bc35a8f7198..82bca8f1fe1b 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -69,9 +69,9 @@ def __init__(self): def __call__(self, graph: fx.Graph): VllmInductorPass.dump_prefix = 0 # reset dump index - shape = get_pass_context().runtime_shape + compile_range = get_pass_context().compile_range for pass_ in self.passes: - if pass_.is_applicable(shape): + if pass_.is_applicable_for_range(compile_range): pass_(graph) VllmInductorPass.dump_prefix += 1 else: diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 2931580afbbb..87b0121f43cb 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -7,7 +7,6 @@ import torch.fx as fx -import vllm.envs as envs from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig @@ -17,8 +16,8 @@ @dataclasses.dataclass -class ConcreteSizeEntry: - runtime_shape: int +class RangeEntry: + compile_range: tuple[int, int] compiled: bool = False runnable: Callable = None # type: ignore @@ -55,7 +54,12 @@ def __init__( self.is_full_graph = total_piecewise_compiles == 1 - self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes) + self.compile_ranges = self.compilation_config.get_compile_ranges() + log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" + logger.debug_once(log_string) + + self.is_in_range = lambda x, range: range[0] <= x < range[1] if range[ + 0] < range[1] else x == range[0] self.first_run_finished = False @@ -63,24 +67,27 @@ def __init__( self.sym_shape_indices = sym_shape_indices - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - # the entries for different shapes that we need to compile - self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + # self.concrete_size_entries: dict[int, RangeEntry] = {} + + # the entries for ranges that we need to either + # TODO: we should merge with concrete_size_entries + self.range_entries: dict[tuple[int, int], RangeEntry] = {} - # to_be_compiled_sizes tracks the remaining sizes to compile, + # to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy it - self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + self.to_be_compiled_ranges: set[tuple[int, + int]] = set(self.compile_ranges) # We only keep compilation management inside this class directly. - for shape in self.compile_sizes: - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, + for range in self.compile_ranges: + self.range_entries[range] = RangeEntry( + compile_range=range, runnable=self.compiled_graph_for_general_shape, ) def check_for_ending_compilation(self): - if self.is_last_graph and not self.to_be_compiled_sizes: + if (self.is_last_graph and not self.to_be_compiled_ranges): # no specific sizes to compile # save the hash of the inductor graph for the next run self.vllm_backend.compiler_manager.save_to_file() @@ -94,28 +101,32 @@ def __call__(self, *args) -> Any: runtime_shape = args[self.sym_shape_indices[0]] - if runtime_shape not in self.concrete_size_entries: + range_entry = None + for range in self.compile_ranges: + if self.is_in_range(runtime_shape, range): + range_entry = self.range_entries[range] + break + + if (range_entry is None): # we don't need to do anything for this shape return self.compiled_graph_for_general_shape(*args) - entry = self.concrete_size_entries[runtime_shape] + if not range_entry.compiled: + range_entry.compiled = True + self.to_be_compiled_ranges.remove(range_entry.compile_range) - if not entry.compiled: - entry.compiled = True - self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( + range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, args, self.compilation_config.inductor_compile_config, self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape, - ) + compile_range=range_entry.compile_range) # finished compilations for all required shapes - if self.is_last_graph and not self.to_be_compiled_sizes: + if (self.is_last_graph and not self.to_be_compiled_ranges): self.check_for_ending_compilation() - return entry.runnable(*args) + return range_entry.runnable(*args) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 31624a8fdcc0..78fd8386f56e 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -482,7 +482,7 @@ def __init__(self, config: VllmConfig): ).register(self.patterns) self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: # When sequence parallelism is enabled, the residual tensor from RMSNorm # needs to be split along the sequence dimension. However, this dimension # is symbolic during piecewise compilation, and splitting symbolic shapes @@ -502,7 +502,9 @@ def is_applicable(self, shape: int | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return compile_range is not None and ( + compile_range[0] + == compile_range[1]) and (compile_range[1] % tp_size == 0) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 72418762773c..374e1c99fea0 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -214,6 +214,8 @@ class CompilationConfig: - Inductor compilation: - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] + - [`compile_ranges_split_points`] + [vllm.config.CompilationConfig.compile_ranges_split_points] - [`inductor_compile_config`] [vllm.config.CompilationConfig.inductor_compile_config] - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] @@ -331,6 +333,16 @@ class CompilationConfig: """Sizes to compile for inductor. In addition to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" + compile_ranges_split_points: Optional[list[int]] = None + """Split points that represent compile ranges for inductor. + The compile ranges are + [1, split_points[0]), + [split_points[0], split_points[1]), ..., + [split_points[-1], max_num_batched_tokens + 1). + Compile sizes are also used single element ranges: + [compile_sizes[i], compile_sizes[i] + 1). + """ + inductor_compile_config: dict = field(default_factory=dict) """Additional configurations for inductor. - None: use default configurations.""" @@ -914,3 +926,24 @@ def custom_op_log_check(self): enable_str, op, ) + + def get_compile_ranges(self) -> list[tuple[int, int]]: + """Get the compile ranges for the compilation config.""" + compile_ranges_split_points = self.compile_ranges_split_points + compile_ranges = [] + # max_num_batched_tokens + 1 + max_split_point = max(compile_ranges_split_points) + compile_sizes = set(self.compile_sizes) + split_points = sorted( + compile_sizes.union(set(self.compile_ranges_split_points))) + # filter out split points that are greater + # than max_num_batched_tokens + 1 + split_points = [x for x in split_points if x <= max_split_point] + for i, s in enumerate(split_points): + if i == 0: + compile_ranges.append((1, s)) + else: + compile_ranges.append((split_points[i - 1], s)) + if s in compile_sizes and s != 1: + compile_ranges.append((s, s)) + return sorted(compile_ranges) From 0cbb0656ac01d60fb3286e63550d215e95caed81 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 4 Sep 2025 10:00:52 -0700 Subject: [PATCH 085/137] Remove general shape graph Signed-off-by: ilmarkov --- vllm/compilation/backends.py | 14 +------ vllm/compilation/piecewise_backend.py | 53 +++++++++++++-------------- vllm/config/compilation.py | 2 + 3 files changed, 30 insertions(+), 39 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 686c415f7ac3..45a1a8c2f267 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -391,17 +391,7 @@ def call_module( sym_shape_indices = [ i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] - global compilation_start_time - compiled_graph_for_dynamic_shape = ( - self.vllm_backend.compiler_manager.compile( - submod, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - compile_range=None, - )) + # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend @@ -411,7 +401,7 @@ def call_module( index, len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_dynamic_shape, + # compiled_graph_for_dynamic_shape, self.vllm_backend, ) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 87b0121f43cb..d280b85fc82a 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -63,15 +63,12 @@ def __init__( self.first_run_finished = False - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa - self.sym_shape_indices = sym_shape_indices # the entries for different shapes that we need to compile # self.concrete_size_entries: dict[int, RangeEntry] = {} # the entries for ranges that we need to either - # TODO: we should merge with concrete_size_entries self.range_entries: dict[tuple[int, int], RangeEntry] = {} # to_be_compiled_ranges tracks the remaining ranges to compile, @@ -81,10 +78,7 @@ def __init__( # We only keep compilation management inside this class directly. for range in self.compile_ranges: - self.range_entries[range] = RangeEntry( - compile_range=range, - runnable=self.compiled_graph_for_general_shape, - ) + self.range_entries[range] = RangeEntry(compile_range=range, ) def check_for_ending_compilation(self): if (self.is_last_graph and not self.to_be_compiled_ranges): @@ -93,24 +87,8 @@ def check_for_ending_compilation(self): self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - return self.compiled_graph_for_general_shape(*args) - - runtime_shape = args[self.sym_shape_indices[0]] - - range_entry = None - for range in self.compile_ranges: - if self.is_in_range(runtime_shape, range): - range_entry = self.range_entries[range] - break - - if (range_entry is None): - # we don't need to do anything for this shape - return self.compiled_graph_for_general_shape(*args) - + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, + args) -> Any: if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) @@ -126,7 +104,28 @@ def __call__(self, *args) -> Any: compile_range=range_entry.compile_range) # finished compilations for all required shapes - if (self.is_last_graph and not self.to_be_compiled_ranges): - self.check_for_ending_compilation() + self.check_for_ending_compilation() + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + + # Role of the general is taken by the last range + range_entry = self.range_entries[self.compile_ranges[-1]] + self._maybe_compile_for_range_entry(range_entry, args) + return range_entry.runnable(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + + range_entry = None + for range in self.compile_ranges: + if self.is_in_range(runtime_shape, range): + range_entry = self.range_entries[range] + break + assert range_entry is not None, \ + f"Shape out of considered range: {runtime_shape} " \ + "[1, max_num_batched_tokens]" + + self._maybe_compile_for_range_entry(range_entry, args) return range_entry.runnable(*args) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 374e1c99fea0..2aab5cb5f295 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -946,4 +946,6 @@ def get_compile_ranges(self) -> list[tuple[int, int]]: compile_ranges.append((split_points[i - 1], s)) if s in compile_sizes and s != 1: compile_ranges.append((s, s)) + assert compile_ranges[-1][1] == max_split_point, \ + "Last compile range end should be max_split_point" return sorted(compile_ranges) From d5392f54cb6e8f15926f1d89642ad08cda44a99c Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 5 Sep 2025 06:00:15 -0700 Subject: [PATCH 086/137] Add test to test pipeline Signed-off-by: ilmarkov --- .buildkite/test-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6cbc25b4b3bf..105eca371ff3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -412,6 +412,7 @@ steps: - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py - pytest -v -s compile/test_aot_compile.py + - pytest -v -s compile/test_compile_ranges.py - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 From 027c9eb348808e1a37c9dbc86fbfcd020e2166a8 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 9 Sep 2025 05:32:05 -0700 Subject: [PATCH 087/137] Fix pre-commit Signed-off-by: ilmarkov --- vllm/compilation/piecewise_backend.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index d280b85fc82a..cec8aca63d80 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -117,12 +117,13 @@ def __call__(self, *args) -> Any: runtime_shape = args[self.sym_shape_indices[0]] - range_entry = None + range_found = False for range in self.compile_ranges: if self.is_in_range(runtime_shape, range): range_entry = self.range_entries[range] + range_found = True break - assert range_entry is not None, \ + assert range_found, \ f"Shape out of considered range: {runtime_shape} " \ "[1, max_num_batched_tokens]" From b2992d3b9afa19156df1453fa504df87ecbc30d9 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 20:12:17 +0000 Subject: [PATCH 088/137] Upd Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 48 ++++++++-------- vllm/compilation/backends.py | 73 ++++++++++++++---------- vllm/compilation/collective_fusion.py | 19 +++--- vllm/compilation/compiler_interface.py | 16 +++--- vllm/compilation/inductor_pass.py | 3 +- vllm/compilation/pass_manager.py | 2 +- vllm/compilation/piecewise_backend.py | 30 +++++----- vllm/compilation/sequence_parallelism.py | 8 ++- vllm/config/compilation.py | 8 ++- 9 files changed, 114 insertions(+), 93 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 6759da199f4b..68389ccfbe14 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -6,8 +6,12 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + VllmConfig, + set_current_vllm_config, +) from vllm.forward_context import set_forward_context from vllm.utils import direct_register_custom_op @@ -18,15 +22,17 @@ MLP_SIZE = 128 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: out.copy_(q) out += k out += v -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention_fake( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: return @@ -41,12 +47,7 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @support_torch_compile class TestModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -59,8 +60,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch.inference_mode -def run_model(vllm_config: VllmConfig, model: nn.Module, - batch_sizes: list[int]): +def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]): with set_forward_context({}, vllm_config=vllm_config): model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) for batch_size in batch_sizes: @@ -68,19 +68,21 @@ def run_model(vllm_config: VllmConfig, model: nn.Module, def test_compile_ranges(): - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - compile_ranges_split_points=[8, 32], - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + compile_ranges_split_points=[8, 32], + ) + ) with set_current_vllm_config(vllm_config): - model = TestModel(vllm_config=vllm_config, prefix='').eval().cuda() + model = TestModel(vllm_config=vllm_config, prefix="").eval().cuda() batch_sizes = [1, 16, 48] # A has support_torch_compile with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=1, - num_backend_compilations=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, model, batch_sizes) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 45a1a8c2f267..beda9b36f686 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -80,8 +80,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[tuple[int, int] | None, int, str], - Any] = (dict()) + self.cache: dict[tuple[tuple[int, int] | None, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -156,20 +155,26 @@ def load( if (compile_range, graph_index, self.compiler.name) not in self.cache: return None handle = self.cache[(compile_range, graph_index, self.compiler.name)] - compiled_graph = self.compiler.load(handle, graph, example_inputs, - graph_index, compile_range) + compiled_graph = self.compiler.load( + handle, graph, example_inputs, graph_index, compile_range + ) if compile_range is None: logger.debug( - "Directly load the %s-th graph for dynamic compile range from %s via handle %s", + "Directly load the %s-th graph for dynamic compile range" + "from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Directly load the %s-th graph for compile range %s from %s via " - "handle %s", graph_index, str(compile_range), - self.compiler.name, handle) + "Directly load the %s-th graph for compile range %s" + "from %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, + ) return compiled_graph def compile( @@ -192,8 +197,7 @@ def compile( compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_index, - compile_range) + compiled_graph = self.load(graph, example_inputs, graph_index, compile_range) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. @@ -209,8 +213,10 @@ def compile( else: logger.info( "Directly load the compiled graph(s) for compile range %s " - "from the cache, took %.3f s", str(compile_range), - elapsed) + "from the cache, took %.3f s", + str(compile_range), + elapsed, + ) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -219,38 +225,43 @@ def compile( # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = \ - f"artifact_compile_range_{compile_range}_subgraph_{graph_index}" - compiled_graph, handle = self.compiler.compile( - graph, example_inputs, additional_inductor_config, compile_range, - maybe_key) + maybe_key = f"artifact_compile_range_{compile_range}_subgraph_{graph_index}" + with self.compile_context(compile_range): + compiled_graph, handle = self.compiler.compile( + graph, + example_inputs, + additional_inductor_config, + compile_range, + maybe_key, + ) assert compiled_graph is not None, "Failed to compile the graph" # store the artifact in the cache if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: - self.cache[(compile_range, graph_index, - self.compiler.name)] = handle + self.cache[(compile_range, graph_index, self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph if compile_range is None: - logger.info_once( - "Cache the graph for dynamic shape for later use") + logger.info_once("Cache the graph for dynamic shape for later use", scope="local") else: - logger.info_once("Cache the graph of compile range %s for later use", - str(compile_range)) + logger.info_once( + "Cache the graph of compile range %s for later use", + str(compile_range), + ) if compile_range is None: logger.debug( - "Store the %s-th graph for dynamic compile range from %s via handle %s", + "Store the %s-th graph for dynamic compile range" + "from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Store the %s-th graph for compile range %s from %s via handle %s", + "Store the %s-th graph for compile range%s from %s via handle %s", graph_index, str(compile_range), self.compiler.name, @@ -264,14 +275,17 @@ def compile( compilation_config.compilation_time += elapsed if compile_range is None: logger.info_once( - "Compiling a graph for dynamic compile range takes %.2f s", - + "Compiling a graph for dynamic compile range takes %.2f s", elapsed, scope="local", ) else: - logger.info_once("Compiling a graph for compile range %s takes %.2f s", - str(compile_range), elapsed, scope="local") + logger.info_once( + "Compiling a graph for compile range %s takes %.2f s", + str(compile_range), + elapsed, + scope="local", + ) return compiled_graph @@ -401,7 +415,6 @@ def call_module( index, len(self.compile_submod_names), sym_shape_indices, - # compiled_graph_for_dynamic_shape, self.vllm_backend, ) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index a4758c971611..3d970ac2964b 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -505,12 +505,12 @@ def call_trtllm_fused_allreduce_norm( element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size max_tensor_size = max_token_num * hidden_size * element_size - assert current_tensor_size <= max_tensor_size, \ - f"Current tensor size {current_tensor_size} is larger than " \ - f"max token num {max_token_num} * hidden size {hidden_size} * " \ + assert current_tensor_size <= max_tensor_size, ( + f"Current tensor size {current_tensor_size} is larger than " + f"max token num {max_token_num} * hidden size {hidden_size} * " f"element size {element_size}" - device_capability = current_platform.get_device_capability( - ).as_version_str() + ) + device_capability = current_platform.get_device_capability().as_version_str() # Get one shot input size limit for the current world size # for the current device capability max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB. \ @@ -520,9 +520,9 @@ def call_trtllm_fused_allreduce_norm( use_oneshot = max_one_shot_size is None or \ current_tensor_size <= max_one_shot_size * MiB - assert ( - _FI_WORKSPACE_TENSOR - is not None), "Flashinfer must be enabled when using flashinfer" + assert _FI_WORKSPACE_TENSOR is not None, ( + "Flashinfer must be enabled when using flashinfer" + ) if norm_out is None: norm_out = allreduce_in residual_out = residual @@ -1178,8 +1178,7 @@ def register_patterns(self): self.disabled = False @VllmInductorPass.time_and_log - def is_applicable_for_range( - self, compile_range: tuple[int, int] | None) -> bool: + def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: if compile_range is None: return False return compile_range[1] - 1 <= self.max_token_num diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 3861bfed11d5..4e5aa077ddae 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -63,14 +63,14 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, with a range. If the `compile_range` is None, it means the `example_inputs` have a dynamic shape. Otherwise, the - `compile_range` specifies the range of the inputs, + `compile_range` specifies the range of the inputs, it could be concrete size, e.g. (4, 4). Right now we only support one variable range of shapes for all inputs, which is the batchsize (number of tokens) during inference. @@ -99,7 +99,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, ) -> Callable: """ Load the compiled function from the handle. @@ -193,7 +193,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -234,7 +234,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -298,7 +298,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -497,7 +497,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -615,7 +615,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 1b4430c82b2d..599fa776b6c0 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -96,8 +96,7 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable_for_range(self, compile_range: tuple[int, - int] | None): + def is_applicable_for_range(self, compile_range: tuple[int, int] | None): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 82bca8f1fe1b..08002dc862f6 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -75,7 +75,7 @@ def __call__(self, graph: fx.Graph): pass_(graph) VllmInductorPass.dump_prefix += 1 else: - logger.debug("Skipping %s with shape %s", pass_, shape) + logger.debug("Skipping %s with compile range %s", pass_, compile_range) # post-cleanup goes before fix_functionalization # because it requires a functional graph diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index cec8aca63d80..607d6a80f5cf 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -30,7 +30,6 @@ def __init__( piecewise_compile_index: int, total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, vllm_backend: VllmBackend, ): """ @@ -58,8 +57,11 @@ def __init__( log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" logger.debug_once(log_string) - self.is_in_range = lambda x, range: range[0] <= x < range[1] if range[ - 0] < range[1] else x == range[0] + self.is_in_range = ( + lambda x, range: range[0] <= x < range[1] + if range[0] < range[1] + else x == range[0] + ) self.first_run_finished = False @@ -73,22 +75,22 @@ def __init__( # to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy it - self.to_be_compiled_ranges: set[tuple[int, - int]] = set(self.compile_ranges) + self.to_be_compiled_ranges: set[tuple[int, int]] = set(self.compile_ranges) # We only keep compilation management inside this class directly. for range in self.compile_ranges: - self.range_entries[range] = RangeEntry(compile_range=range, ) + self.range_entries[range] = RangeEntry( + compile_range=range, + ) def check_for_ending_compilation(self): - if (self.is_last_graph and not self.to_be_compiled_ranges): + if self.is_last_graph and not self.to_be_compiled_ranges: # no specific sizes to compile # save the hash of the inductor graph for the next run self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, - args) -> Any: + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) @@ -101,7 +103,8 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - compile_range=range_entry.compile_range) + compile_range=range_entry.compile_range, + ) # finished compilations for all required shapes self.check_for_ending_compilation() @@ -123,9 +126,10 @@ def __call__(self, *args) -> Any: range_entry = self.range_entries[range] range_found = True break - assert range_found, \ - f"Shape out of considered range: {runtime_shape} " \ - "[1, max_num_batched_tokens]" + assert range_found, ( + f"Shape out of considered range: {runtime_shape} " + "[1, max_num_batched_tokens]" + ) self._maybe_compile_for_range_entry(range_entry, args) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 78fd8386f56e..cf47adb4670a 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -502,9 +502,11 @@ def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool ): return True tp_size = get_tensor_model_parallel_world_size() - return compile_range is not None and ( - compile_range[0] - == compile_range[1]) and (compile_range[1] % tp_size == 0) + return ( + compile_range is not None + and (compile_range[0] == compile_range[1]) + and (compile_range[1] % tp_size == 0) + ) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 2aab5cb5f295..278fe5801323 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -333,7 +333,7 @@ class CompilationConfig: """Sizes to compile for inductor. In addition to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" - compile_ranges_split_points: Optional[list[int]] = None + compile_ranges_split_points: list[int] | None = None """Split points that represent compile ranges for inductor. The compile ranges are [1, split_points[0]), @@ -935,7 +935,8 @@ def get_compile_ranges(self) -> list[tuple[int, int]]: max_split_point = max(compile_ranges_split_points) compile_sizes = set(self.compile_sizes) split_points = sorted( - compile_sizes.union(set(self.compile_ranges_split_points))) + compile_sizes.union(set(self.compile_ranges_split_points)) + ) # filter out split points that are greater # than max_num_batched_tokens + 1 split_points = [x for x in split_points if x <= max_split_point] @@ -946,6 +947,7 @@ def get_compile_ranges(self) -> list[tuple[int, int]]: compile_ranges.append((split_points[i - 1], s)) if s in compile_sizes and s != 1: compile_ranges.append((s, s)) - assert compile_ranges[-1][1] == max_split_point, \ + assert compile_ranges[-1][1] == max_split_point, ( "Last compile range end should be max_split_point" + ) return sorted(compile_ranges) From 3499384c1e183cd851c93d12ea7d77c08de03ed2 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 20:32:36 +0000 Subject: [PATCH 089/137] Upd config Signed-off-by: ilmarkov --- vllm/config/vllm.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 916f258d6586..fd38992e374b 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -426,6 +426,8 @@ def __post_init__(self): "correctness and to realize prefill savings. " ) + self._set_compile_ranges() + disable_chunked_prefill_reasons: list[str] = [] if self.model_config: @@ -796,6 +798,49 @@ def _set_cudagraph_sizes(self): # complete the remaining process. self.compilation_config.post_init_cudagraph_sizes() + def _set_compile_ranges(self): + """ + Set the compile ranges for the compilation config. + """ + compilation_config = self.compilation_config + computed_compile_ranges_split_points = [] + + # The upper bound of the compile ranges is the max_num_batched_tokens + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + if max_num_batched_tokens is not None: + # We add 1 because the bounds checks in the compiler are exclusive + # and we want to include the max_num_batched_tokens + # in the compile range + computed_compile_ranges_split_points.append(max_num_batched_tokens + 1) + + # Add the compile ranges for flashinfer + if compilation_config.pass_config.enable_fi_allreduce_fusion: + tp_size = self.parallel_config.tensor_parallel_size + max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) + if max_size is not None: + max_token_num = max_size // ( + self.model_config.get_hidden_size() + * self.model_config.dtype.itemsize + ) + # We add 1 because the bounds checks in the compiler are + # exclusive and we want to include the max_token_num in the + # compile range + computed_compile_ranges_split_points.append(max_token_num + 1) + + if compilation_config.compile_ranges_split_points is not None: + for x in compilation_config.compile_ranges_split_points: + assert isinstance(x, int) + assert x > 0, f"Invalid compile range split point: {x}" + if ( + max_num_batched_tokens is not None + and x < max_num_batched_tokens + and x > 1 + ): + computed_compile_ranges_split_points.append(x) + compilation_config.compile_ranges_split_points = sorted( + computed_compile_ranges_split_points + ) # type: ignore + def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config model_config = self.model_config From 5336ee6ffe1d5b03b69b23f4b346ba10a549c6cd Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 20:51:01 +0000 Subject: [PATCH 090/137] Fix Signed-off-by: ilmarkov --- vllm/compilation/collective_fusion.py | 18 ++++++++++-------- vllm/v1/worker/utils.py | 2 +- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 3d970ac2964b..7c0a1208d870 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -431,7 +431,7 @@ def __init__(self, config: VllmConfig): self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: # This pass is applied on top of the sequence parallelism pass. # It inherits the same applicability condition as `SequenceParallelismPass`. # See `SequenceParallelismPass.is_applicable` for more details. @@ -441,7 +441,9 @@ def is_applicable(self, shape: int | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return compile_range is not None and ( + compile_range[0] == compile_range[1] and compile_range[1] % tp_size == 0 + ) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): @@ -1100,18 +1102,18 @@ def __init__(self, config: VllmConfig): ) return element_size = 4 if use_fp32_lamport else 2 - max_token_num = max_size // (self.hidden_dim * element_size) + self.max_token_num = max_size // (self.hidden_dim * element_size) # take the min to save workspace size and we'll never use more # than max_num_batched_tokens anyways - max_token_num = min( - max_token_num, config.scheduler_config.max_num_batched_tokens + self.max_token_num = min( + self.max_token_num, config.scheduler_config.max_num_batched_tokens ) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=self.tp_size, - max_token_num=max_token_num, + max_token_num=self.max_token_num, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -1124,7 +1126,7 @@ def __init__(self, config: VllmConfig): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=max_token_num, + max_token_num=self.max_token_num, ) self.register_patterns() @@ -1177,12 +1179,12 @@ def register_patterns(self): self.disabled = False - @VllmInductorPass.time_and_log def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: if compile_range is None: return False return compile_range[1] - 1 <= self.max_token_num + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): if self.disabled: logger.debug("AllReduceFusionPass disabled") diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 92baf0cb7136..ef953dd2051e 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -330,7 +330,7 @@ def is_residual_scattered_for_sp( The residual tensor is scattered across tensor parallel ranks when sequence parallelism and tensor parallelism is enabled. - This follows the same logic as SequenceParallelismPass.is_applicable(): + This follows the same logic as SequenceParallelismPass.is_applicable_for_range(): - In full-graph compilation mode (no splitting ops or using inductor graph partition), SP is always applied - Otherwise, SP is only applied for specific shapes in compile_sizes From 4958474f77a930f532730a9ec7a395339ea32138 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 17 Oct 2025 11:30:21 +0000 Subject: [PATCH 091/137] Priotitize compile_sizes Signed-off-by: ilmarkov --- vllm/compilation/piecewise_backend.py | 28 ++++++++++++++++++++------- vllm/config/compilation.py | 18 ++--------------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 607d6a80f5cf..7a10fed1d237 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -57,6 +57,10 @@ def __init__( log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" logger.debug_once(log_string) + self.compile_sizes = self.compilation_config.compile_sizes + log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" + logger.debug_once(log_string) + self.is_in_range = ( lambda x, range: range[0] <= x < range[1] if range[0] < range[1] @@ -78,6 +82,12 @@ def __init__( self.to_be_compiled_ranges: set[tuple[int, int]] = set(self.compile_ranges) # We only keep compilation management inside this class directly. + for size in self.compile_sizes: + range = (size, size) + self.range_entries[range] = RangeEntry( + compile_range=range, + ) + for range in self.compile_ranges: self.range_entries[range] = RangeEntry( compile_range=range, @@ -112,20 +122,24 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: def __call__(self, *args) -> Any: if not self.first_run_finished: self.first_run_finished = True + self.check_for_ending_compilation() - # Role of the general is taken by the last range + # Role of the general graph is taken by the last range graph range_entry = self.range_entries[self.compile_ranges[-1]] self._maybe_compile_for_range_entry(range_entry, args) return range_entry.runnable(*args) - runtime_shape = args[self.sym_shape_indices[0]] range_found = False - for range in self.compile_ranges: - if self.is_in_range(runtime_shape, range): - range_entry = self.range_entries[range] - range_found = True - break + if runtime_shape in self.compile_sizes: + range_entry = self.range_entries[(runtime_shape, runtime_shape)] + range_found = True + else: + for range in self.compile_ranges: + if self.is_in_range(runtime_shape, range): + range_entry = self.range_entries[range] + range_found = True + break assert range_found, ( f"Shape out of considered range: {runtime_shape} " "[1, max_num_batched_tokens]" diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 278fe5801323..c2a6d6d783b9 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -929,25 +929,11 @@ def custom_op_log_check(self): def get_compile_ranges(self) -> list[tuple[int, int]]: """Get the compile ranges for the compilation config.""" - compile_ranges_split_points = self.compile_ranges_split_points + split_points = self.compile_ranges_split_points compile_ranges = [] - # max_num_batched_tokens + 1 - max_split_point = max(compile_ranges_split_points) - compile_sizes = set(self.compile_sizes) - split_points = sorted( - compile_sizes.union(set(self.compile_ranges_split_points)) - ) - # filter out split points that are greater - # than max_num_batched_tokens + 1 - split_points = [x for x in split_points if x <= max_split_point] for i, s in enumerate(split_points): if i == 0: compile_ranges.append((1, s)) else: compile_ranges.append((split_points[i - 1], s)) - if s in compile_sizes and s != 1: - compile_ranges.append((s, s)) - assert compile_ranges[-1][1] == max_split_point, ( - "Last compile range end should be max_split_point" - ) - return sorted(compile_ranges) + return compile_ranges From 04306ed0dacf3fc11bcfb5ae993095d8d5a506bb Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 28 Oct 2025 13:26:59 +0000 Subject: [PATCH 092/137] Fix inductor config Signed-off-by: ilmarkov --- vllm/compilation/backends.py | 7 ++++++- vllm/compilation/compiler_interface.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index beda9b36f686..30ab91e4ab82 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -225,7 +225,12 @@ def compile( # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = f"artifact_compile_range_{compile_range}_subgraph_{graph_index}" + maybe_key = "artifact_compile_range_" + if compile_range is None: + maybe_key += "dynamic_shape" + else: + maybe_key += f"{compile_range[0]}_{compile_range[1]}" + maybe_key += f"_subgraph_{graph_index}" with self.compile_context(compile_range): compiled_graph, handle = self.compiler.compile( graph, diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 4e5aa077ddae..d069769fe76f 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -594,8 +594,8 @@ def metrics_context(self) -> contextlib.AbstractContextManager: def set_inductor_config(config, compile_range): - if isinstance(compile_range, tuple): - # for a specific range of batchsizes, tuning triton kernel parameters + if isinstance(compile_range, tuple) and compile_range[0] == compile_range[1]: + # for a specific batch size, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE config["coordinate_descent_tuning"] = ( From 9dc4eea25b0ec2520d920616002a6f148a1c3801 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 3 Nov 2025 10:53:49 +0000 Subject: [PATCH 093/137] Laith's fix Signed-off-by: ilmarkov --- vllm/compilation/compiler_interface.py | 38 +++++++++++++++++++++----- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index d069769fe76f..3453b8f676e8 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -213,13 +213,37 @@ def compile( from torch._inductor import standalone_compile - compiled_graph = standalone_compile( - graph, - example_inputs, - dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}, - ) - + if dynamic_shapes == "from_graph": + # We need to pass fake example_inputs, otherwise torch.compile + # will fakify the example_inputs potentially causing some non dynamic + # dimension to be be duck shaped to other existing shapes that have hints + # matching their values. + # This is problem because it can lead to unintended specializations! + # if the new wrongly dynamic dim is specialized + # it will force specializing the whole shape + # standalone_compile probably should not accept + # non fake tensors as example inputs! + fake_example_inputs = [] + for node in graph.graph.nodes: + # All place holders come first + if node.op == "placeholder": + fake_example_inputs.append(node.meta["example_value"]) + else: + break + assert len(fake_example_inputs) == len(example_inputs) + compiled_graph = standalone_compile( + graph, + fake_example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}, + ) + else: + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}, + ) # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) From 2c63f0b05c02ce4d93e23093b3838af775d92614 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 4 Nov 2025 10:22:17 +0000 Subject: [PATCH 094/137] Upd Signed-off-by: ilmarkov --- vllm/compilation/backends.py | 6 ++++-- vllm/compilation/collective_fusion.py | 11 ++++++----- vllm/config/compilation.py | 3 +++ 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 30ab91e4ab82..7cda5d0dee96 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -250,7 +250,9 @@ def compile( if graph_index == 0: # adds some info logging for the first graph if compile_range is None: - logger.info_once("Cache the graph for dynamic shape for later use", scope="local") + logger.info_once( + "Cache the graph for dynamic shape for later use", scope="local" + ) else: logger.info_once( "Cache the graph of compile range %s for later use", @@ -280,7 +282,7 @@ def compile( compilation_config.compilation_time += elapsed if compile_range is None: logger.info_once( - "Compiling a graph for dynamic compile range takes %.2f s", + "Compiling a graph for dynamic compile range takes %.2f s", elapsed, scope="local", ) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7c0a1208d870..9c20db07c267 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -515,12 +515,13 @@ def call_trtllm_fused_allreduce_norm( device_capability = current_platform.get_device_capability().as_version_str() # Get one shot input size limit for the current world size # for the current device capability - max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB. \ - get(device_capability, {}). \ - get(world_size, None) + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( + device_capability, {} + ).get(world_size, None) # Use one shot if no max size is specified - use_oneshot = max_one_shot_size is None or \ - current_tensor_size <= max_one_shot_size * MiB + use_oneshot = ( + max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB + ) assert _FI_WORKSPACE_TENSOR is not None, ( "Flashinfer must be enabled when using flashinfer" diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c2a6d6d783b9..e469c8e25a43 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -142,6 +142,9 @@ def flashinfer_max_size(self, world_size: int) -> int | None: max_sizes = { k: int(v * MiB) for k, v in self.fi_allreduce_fusion_max_size_mb.items() } + logger.debug_once( + f"flashinfer_max_size: {max_sizes.get(world_size)}", scope="global" + ) # return None if world size is not supported by flashinfer return max_sizes.get(world_size) From 67f7ae18e39bc7834581fb0879e431e2734cbf29 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 4 Nov 2025 10:28:36 +0000 Subject: [PATCH 095/137] Update config Signed-off-by: ilmarkov --- vllm/config/compilation.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 72418762773c..d25aa23131e6 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -110,12 +110,12 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_size_mb: dict[int, float] = field(default_factory=dict) - """The thresholds of the communicated tensor sizes under which + fi_allreduce_fusion_max_size_mb: float | None = None + """The threshold of the communicated tensor sizes under which vllm should use flashinfer fused allreduce. Specified as a - dictionary mapping each world size to the threshold in MB - { : } - Unspecified world sizes will fallback to + float in MB. + Unspecified will fallback to default values + which are compute capability and world size dependent. FI_ALLREDUCE_FUSION_MAX_SIZE_MB = { "9.0": { 2: 64, # 64MB @@ -139,12 +139,11 @@ def flashinfer_max_size(self, world_size: int) -> int | None: """ MiB = 1024 * 1024 - max_sizes = { - k: int(v * MiB) for k, v in self.fi_allreduce_fusion_max_size_mb.items() - } + max_size_mb = self.fi_allreduce_fusion_max_size_mb + if max_size_mb is None: + max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) - # return None if world size is not supported by flashinfer - return max_sizes.get(world_size) + return int(max_size_mb * MiB) if max_size_mb is not None else None @staticmethod def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]: @@ -181,11 +180,6 @@ def __post_init__(self) -> None: "Allreduce + rms norm + quant (fp8) fusion might not work" ) - self.fi_allreduce_fusion_max_size_mb = { - **PassConfig.default_fi_allreduce_fusion_max_size_mb(), - **self.fi_allreduce_fusion_max_size_mb, - } - @config @dataclass From fcebc21fb1708abbfc2622cfeee517aef801c622 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 4 Nov 2025 14:30:18 +0000 Subject: [PATCH 096/137] Add caching Signed-off-by: ilmarkov --- vllm/compilation/compiler_interface.py | 37 +++++--------------------- vllm/compilation/pass_manager.py | 1 + vllm/compilation/piecewise_backend.py | 23 +++++++++++++++- vllm/config/compilation.py | 8 +++--- 4 files changed, 32 insertions(+), 37 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 3453b8f676e8..6a57cd4bc578 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -213,37 +213,12 @@ def compile( from torch._inductor import standalone_compile - if dynamic_shapes == "from_graph": - # We need to pass fake example_inputs, otherwise torch.compile - # will fakify the example_inputs potentially causing some non dynamic - # dimension to be be duck shaped to other existing shapes that have hints - # matching their values. - # This is problem because it can lead to unintended specializations! - # if the new wrongly dynamic dim is specialized - # it will force specializing the whole shape - # standalone_compile probably should not accept - # non fake tensors as example inputs! - fake_example_inputs = [] - for node in graph.graph.nodes: - # All place holders come first - if node.op == "placeholder": - fake_example_inputs.append(node.meta["example_value"]) - else: - break - assert len(fake_example_inputs) == len(example_inputs) - compiled_graph = standalone_compile( - graph, - fake_example_inputs, - dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}, - ) - else: - compiled_graph = standalone_compile( - graph, - example_inputs, - dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}, - ) + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}, + ) # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 08002dc862f6..3e0c9bc99a24 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -155,5 +155,6 @@ def uuid(self): # See [HACK: Bug with Inductor graph partition and torch.compile cache] state["inductor_splitting_ops"].extend(self.inductor_splitting_ops) + state["compile_range"] = get_pass_context().compile_range return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 7a10fed1d237..ad5b49f28550 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -87,6 +87,7 @@ def __init__( self.range_entries[range] = RangeEntry( compile_range=range, ) + self.to_be_compiled_ranges.add(range) for range in self.compile_ranges: self.range_entries[range] = RangeEntry( @@ -100,6 +101,26 @@ def check_for_ending_compilation(self): self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) + def fakify_args(self, args: list[Any]) -> list[Any]: + # We need to pass fake example_inputs, otherwise torch.compile + # will fakify the example_inputs potentially causing some non dynamic + # dimension to be be duck shaped to other existing shapes that have hints + # matching their values. + # This is problem because it can lead to unintended specializations! + # if the new wrongly dynamic dim is specialized + # it will force specializing the whole shape + # torch.compile probably should not accept + # non fake tensors as example inputs! + fake_example_inputs = [] + for node in self.graph.graph.nodes: + # All place holders come first + if node.op == "placeholder": + fake_example_inputs.append(node.meta["example_value"]) + else: + break + assert len(fake_example_inputs) == len(args) + return fake_example_inputs + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: if not range_entry.compiled: range_entry.compiled = True @@ -108,7 +129,7 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: # args are real arguments range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, - args, + self.fakify_args(args), self.compilation_config.inductor_compile_config, self.compilation_config, graph_index=self.piecewise_compile_index, diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 475f4c15afef..fa728c23d145 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -142,11 +142,9 @@ def flashinfer_max_size(self, world_size: int) -> int | None: max_size_mb = self.fi_allreduce_fusion_max_size_mb if max_size_mb is None: max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) - logger.debug_once( - f"flashinfer_max_size: {int(max_size_mb * MiB)}", scope="global" - ) - return int(max_size_mb * MiB) - return None + max_size_bytes = int(max_size_mb * MiB) if max_size_mb is not None else None + logger.debug_once(f"flashinfer_max_size: {max_size_bytes}", scope="global") + return max_size_bytes @staticmethod def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]: From 65151bcecf8429890f4fa191e7988aedfb2c9aa5 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 12:58:20 +0000 Subject: [PATCH 097/137] Address comments Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 65 +++++++++++++++------------ vllm/compilation/collective_fusion.py | 5 +++ vllm/config/compilation.py | 1 - 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 68389ccfbe14..03f31df1ece7 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -1,19 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from torch import fx as fx from torch import nn from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.inductor_pass import ( + CustomGraphPass, + InductorPass, + get_pass_context, +) from vllm.config import ( - CompilationConfig, - CompilationLevel, VllmConfig, set_current_vllm_config, ) +from vllm.config.compilation import CompilationConfig, CompilationMode +from vllm.config.scheduler import SchedulerConfig from vllm.forward_context import set_forward_context -from vllm.utils import direct_register_custom_op # create a library to hold the custom op silly_lib = Library("silly", "FRAGMENT") # noqa @@ -22,29 +27,6 @@ MLP_SIZE = 128 -def silly_attention( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor -) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor -) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @support_torch_compile class TestModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: @@ -67,12 +49,37 @@ def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]) model(torch.randn(batch_size, MLP_SIZE).cuda()) +class PostGradPassManagerCheckRanges(CustomGraphPass): + def __init__(self, ranges: list[tuple[int, int]]): + self.ranges = ranges + + def __call__(self, graph: fx.Graph): + compile_range = get_pass_context().compile_range + assert compile_range in self.ranges, ( + f"Compile range {compile_range} not in {self.ranges}" + ) + + def uuid(self) -> str: + state = { + "ranges": self.ranges, + } + return InductorPass.hash_dict(state) + + def test_compile_ranges(): vllm_config = VllmConfig( + scheduler_config=SchedulerConfig( + max_num_batched_tokens=8192, + ), compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, compile_ranges_split_points=[8, 32], - ) + ), + inductor_compile_config={ + "post_grad_custom_post_pass": PostGradPassManagerCheckRanges( + [(1, 8), (8, 32), (32, 2049)] + ) + }, ) with set_current_vllm_config(vllm_config): @@ -82,7 +89,7 @@ def test_compile_ranges(): with compilation_counter.expect( num_graphs_seen=1, num_piecewise_graphs_seen=1, - num_backend_compilations=4, + num_backend_compilations=3, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, model, batch_sizes) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 9c20db07c267..aaf53c6e5768 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1109,6 +1109,11 @@ def __init__(self, config: VllmConfig): self.max_token_num = min( self.max_token_num, config.scheduler_config.max_num_batched_tokens ) + logger.debug_once( + f"Flashinfer max size: {max_size // (1024 * 1024)} MB" + f", Maximal number of tokens: {self.max_token_num}", + scope="global", + ) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index fa728c23d145..6e50493a770c 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -143,7 +143,6 @@ def flashinfer_max_size(self, world_size: int) -> int | None: if max_size_mb is None: max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) max_size_bytes = int(max_size_mb * MiB) if max_size_mb is not None else None - logger.debug_once(f"flashinfer_max_size: {max_size_bytes}", scope="global") return max_size_bytes @staticmethod From 1f7afdb3a784e5803161f1927b99a0c1f7d47575 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 14:29:57 +0000 Subject: [PATCH 098/137] Add debug log Signed-off-by: ilmarkov --- vllm/compilation/collective_fusion.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index cf89182357f2..0a56ada54165 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1140,6 +1140,12 @@ def __init__(self, config: VllmConfig): max_token_num = min( max_token_num, config.scheduler_config.max_num_batched_tokens ) + logger.debug_once( + f"Flashinfer max size: {max_size // (1024 * 1024)} MB," + "Maximal number of tokens used by " + f"Flashinfer Allreduce Fusion: {self.max_token_num}", + scope="global", + ) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( From df22202272995c4a9c99f1ae7c562416d9620e53 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 11:25:17 -0500 Subject: [PATCH 099/137] Update benchmark Signed-off-by: ilmarkov --- benchmarks/kernels/benchmark_fused_collective.py | 16 ++++++++++++---- vllm/config/compilation.py | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index cec134ff9138..d7fa0580a3e7 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -410,6 +410,7 @@ def run_benchmarks( use_residual: bool, allreduce_params: FlashInferFusedAllReduceParams | None, quant_modes: set[str], + no_oneshot: bool, ): """Run all benchmarks for given configuration. @@ -431,6 +432,7 @@ def run_benchmarks( rms_eps = 1e-6 results = {} vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + use_oneshot_options = [False] if no_oneshot else [True, False] # Create RMSNorm and QuantFP8 layers once for native benchmarks @@ -476,7 +478,7 @@ def run_benchmarks( # FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot if flashinfer_comm is not None and allreduce_params is not None: - for use_oneshot in [True, False]: + for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" try: time_ms = benchmark_operation( @@ -560,7 +562,7 @@ def run_benchmarks( # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot if flashinfer_comm is not None and allreduce_params is not None: - for use_oneshot in [True, False]: + for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" try: time_ms = benchmark_operation( @@ -645,7 +647,7 @@ def run_benchmarks( # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot if flashinfer_comm is not None and allreduce_params is not None: - for use_oneshot in [True, False]: + for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" try: time_ms = benchmark_operation( @@ -901,7 +903,7 @@ def save_results_to_file( try: markdown_content = format_results_markdown(all_results, world_size, args) - with open(output_path, "w") as f: + with open(output_path, "a") as f: f.write(markdown_content) except Exception as e: @@ -960,6 +962,12 @@ def main(): """, ) + parser.add_argument( + "--no-oneshot", + action="store_true", + help="Skip oneshot benchmarks", + ) + args = parser.parse_args() # Check if running with torchrun (required for collective operations) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 6e50493a770c..6f35673856df 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -923,7 +923,7 @@ def custom_op_log_check(self): def get_compile_ranges(self) -> list[tuple[int, int]]: """Get the compile ranges for the compilation config.""" - split_points = self.compile_ranges_split_points + split_points = set(self.compile_ranges_split_points) compile_ranges = [] for i, s in enumerate(split_points): if i == 0: From a21de2baef2202f2610788027c904f9b377752e9 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 16:32:59 +0000 Subject: [PATCH 100/137] Fix Signed-off-by: ilmarkov --- benchmarks/kernels/benchmark_fused_collective.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index d7fa0580a3e7..99213d0c7cc2 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -1076,6 +1076,7 @@ def main(): use_residual, allreduce_params, quant_modes=quant_modes, + no_oneshot=args.no_oneshot, ) # Store results for markdown export From 45f4093548f0b5e89623906e507440d27ef6732b Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 20:39:14 +0000 Subject: [PATCH 101/137] Update bench and constants Signed-off-by: ilmarkov --- .../kernels/benchmark_fused_collective.py | 17 +++++++++++++---- vllm/compilation/collective_fusion.py | 8 ++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index cec134ff9138..99213d0c7cc2 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -410,6 +410,7 @@ def run_benchmarks( use_residual: bool, allreduce_params: FlashInferFusedAllReduceParams | None, quant_modes: set[str], + no_oneshot: bool, ): """Run all benchmarks for given configuration. @@ -431,6 +432,7 @@ def run_benchmarks( rms_eps = 1e-6 results = {} vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + use_oneshot_options = [False] if no_oneshot else [True, False] # Create RMSNorm and QuantFP8 layers once for native benchmarks @@ -476,7 +478,7 @@ def run_benchmarks( # FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot if flashinfer_comm is not None and allreduce_params is not None: - for use_oneshot in [True, False]: + for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" try: time_ms = benchmark_operation( @@ -560,7 +562,7 @@ def run_benchmarks( # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot if flashinfer_comm is not None and allreduce_params is not None: - for use_oneshot in [True, False]: + for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" try: time_ms = benchmark_operation( @@ -645,7 +647,7 @@ def run_benchmarks( # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot if flashinfer_comm is not None and allreduce_params is not None: - for use_oneshot in [True, False]: + for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" try: time_ms = benchmark_operation( @@ -901,7 +903,7 @@ def save_results_to_file( try: markdown_content = format_results_markdown(all_results, world_size, args) - with open(output_path, "w") as f: + with open(output_path, "a") as f: f.write(markdown_content) except Exception as e: @@ -960,6 +962,12 @@ def main(): """, ) + parser.add_argument( + "--no-oneshot", + action="store_true", + help="Skip oneshot benchmarks", + ) + args = parser.parse_args() # Check if running with torchrun (required for collective operations) @@ -1068,6 +1076,7 @@ def main(): use_residual, allreduce_params, quant_modes=quant_modes, + no_oneshot=args.no_oneshot, ) # Store results for markdown export diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 0a56ada54165..52ec5e5fd3b3 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -451,11 +451,11 @@ def __call__(self, graph: fx.Graph): # Max size of the input tensor per world size per device capability # to use flashinfer fused allreduce -FI_ALLREDUCE_FUSION_MAX_SIZE_MB = { +FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[str, dict[int, float]] = { "9.0": { 2: 64, # 64MB 4: 2, # 2MB - 8: 1, # 1MB + 8: 0.5, # 0.5MB }, "10.0": { 2: 64, # 64MB @@ -466,11 +466,11 @@ def __call__(self, graph: fx.Graph): # Max size of the input tensor per world size per device capability # to use flashinfer one shot fused allreduce -_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB = { +_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[str, dict[int, float]] = { "9.0": { 2: 32, # 32MB 4: 2, # 2MB - 8: 1, # 1MB + 8: 0.5, # 0.5MB }, "10.0": { 2: 32, # 32MB From c26e056066a8c27af09a6ab8a44bb3aee7192bbd Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 16:04:58 -0500 Subject: [PATCH 102/137] Rename in benchmark Signed-off-by: ilmarkov --- .../kernels/benchmark_fused_collective.py | 44 +++++++++++-------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index 99213d0c7cc2..38e7fdcf5542 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -331,10 +331,10 @@ def allreduce_rmsnorm_fp4_quant( def create_test_tensors( - seq_len: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True + num_tokens: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True ): """Create test tensors for benchmarking.""" - input_tensor = torch.randn(seq_len, hidden_dim, dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_dim, dtype=dtype) residual = ( torch.randn_like(input_tensor) if use_residual @@ -348,7 +348,7 @@ def create_test_tensors( scale_fp4 = torch.tensor(1.0, dtype=torch.float32) quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE) # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks) - fp4_quant_out = torch.empty((seq_len, hidden_dim // 2), dtype=torch.uint8) + fp4_quant_out = torch.empty((num_tokens, hidden_dim // 2), dtype=torch.uint8) fp4_output_scale = torch.empty((128, 4), dtype=torch.int32) return ( @@ -404,7 +404,7 @@ def benchmark_operation( def run_benchmarks( - seq_len: int, + num_tokens: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool, @@ -427,7 +427,7 @@ def run_benchmarks( scale_fp4, fp4_quant_out, fp4_output_scale, - ) = create_test_tensors(seq_len, hidden_dim, dtype, use_residual) + ) = create_test_tensors(num_tokens, hidden_dim, dtype, use_residual) rms_eps = 1e-6 results = {} @@ -806,12 +806,18 @@ def get_fastest_baseline(op_name, results_dict): def print_results( - results_dict, seq_len, hidden_dim, dtype, use_residual, quant_modes, input_size_mb + results_dict, + num_tokens, + hidden_dim, + dtype, + use_residual, + quant_modes, + input_size_mb, ): """Print benchmark results in a formatted table.""" print(f"\n{'=' * 80}") print( - f"Results: seq_len={seq_len}, hidden_dim={hidden_dim} " + f"Results: num_tokens={num_tokens}, hidden_dim={hidden_dim} " f"(input size: {input_size_mb:.2f} MB)" ) print( @@ -854,7 +860,7 @@ def format_results_markdown( lines.append("") for entry in all_results: - seq_len = entry["seq_len"] + num_tokens = entry["num_tokens"] dtype = entry["dtype"] use_residual = entry["use_residual"] results_dict = entry["results"] @@ -862,7 +868,7 @@ def format_results_markdown( residual_str = "with residual" if use_residual else "no residual" lines.append( - f"## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str}" + f"## Configuration: num_tokens={num_tokens}, dtype={dtype}, {residual_str}" ) lines.append(f"**Input Size:** {input_size_mb:.2f} MB") lines.append("") @@ -915,11 +921,11 @@ def main(): description="Benchmark fused collective operations" ) parser.add_argument( - "--seq-lens", + "--num-tokens", type=int, nargs="+", default=[128, 512, 1024, 2048], - help="Sequence lengths to test", + help="Numbers of tokens to test", ) parser.add_argument( "--hidden-dim", type=int, default=8192, help="Hidden dimension size" @@ -1031,7 +1037,7 @@ def main(): # Test configurations residual_options = [True] if not args.no_residual else [False] - configs = list(itertools.product(args.seq_lens, dtypes, residual_options)) + configs = list(itertools.product(args.num_tokens, dtypes, residual_options)) # Setup FlashInfer workspace if available ipc_handles = None @@ -1059,18 +1065,18 @@ def main(): try: # Run benchmarks - for seq_len, dtype, use_residual in configs: + for num_tokens, dtype, use_residual in configs: if rank == 0: logger.info( - "\nTesting: seq_len=%s, hidden_dim=%s, dtype=%s, residual=%s", - seq_len, + "\nTesting: num_tokens=%s, hidden_dim=%s, dtype=%s, residual=%s", + num_tokens, args.hidden_dim, dtype, use_residual, ) results = run_benchmarks( - seq_len, + num_tokens, args.hidden_dim, dtype, use_residual, @@ -1083,11 +1089,11 @@ def main(): if rank == 0: # Calculate input size in MB input_size_mb = ( - seq_len * args.hidden_dim * torch.finfo(dtype).bits + num_tokens * args.hidden_dim * torch.finfo(dtype).bits ) / (8 * 1024 * 1024) all_results.append( { - "seq_len": seq_len, + "num_tokens": num_tokens, "hidden_dim": args.hidden_dim, "dtype": str(dtype).replace("torch.", ""), "use_residual": use_residual, @@ -1099,7 +1105,7 @@ def main(): print_results( results, - seq_len, + num_tokens, args.hidden_dim, dtype, use_residual, From bcc0cc0fd9de0c702ab6c8dee3fd2479ffd73bb9 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 21:49:18 +0000 Subject: [PATCH 103/137] Add max_token_num to object Signed-off-by: ilmarkov --- vllm/compilation/collective_fusion.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 52ec5e5fd3b3..fa8d705eb51a 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1134,11 +1134,11 @@ def __init__(self, config: VllmConfig): ) return element_size = 4 if use_fp32_lamport else 2 - max_token_num = max_size // (self.hidden_dim * element_size) + self.max_token_num = max_size // (self.hidden_dim * element_size) # take the min to save workspace size and we'll never use more # than max_num_batched_tokens anyways - max_token_num = min( - max_token_num, config.scheduler_config.max_num_batched_tokens + self.max_token_num = min( + self.max_token_num, config.scheduler_config.max_num_batched_tokens ) logger.debug_once( f"Flashinfer max size: {max_size // (1024 * 1024)} MB," @@ -1151,7 +1151,7 @@ def __init__(self, config: VllmConfig): flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=self.tp_size, - max_token_num=max_token_num, + max_token_num=self.max_token_num, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -1164,7 +1164,7 @@ def __init__(self, config: VllmConfig): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=max_token_num, + max_token_num=self.max_token_num, ) self.register_patterns() From 43b163c5bad69b83268acc5eee8ef6aa0375e117 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 22:06:02 +0000 Subject: [PATCH 104/137] Add test Signed-off-by: ilmarkov --- .buildkite/test-pipeline.yaml | 4 ++-- tests/compile/test_fusions_e2e.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 23830c4e0781..c6cc873b44d7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -462,8 +462,8 @@ steps: - pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/piecewise/ -- label: PyTorch Fullgraph Test # 22min - timeout_in_minutes: 35 +- label: PyTorch Fullgraph Test # 27min + timeout_in_minutes: 40 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index d66c60ccb5b2..3c54c2adbf48 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -73,6 +73,16 @@ class ModelBackendTestCase(NamedTuple): ), ] + MODELS_MOE = [ + ModelBackendTestCase( + model_name="Qwen/Qwen3-30B-A3B", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=0, + allreduce_fusions=97, + ), + ] + elif current_platform.is_rocm(): MODELS_FP8 = [ ModelBackendTestCase( @@ -191,7 +201,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: ) ) # Toggle RMSNorm for FP4 models and unquant models - + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), + + list(flat_product(MODELS_FP4 + MODELS + MODELS_MOE, CUSTOM_OPS_RMS_NORM)), ) @pytest.mark.parametrize("inductor_graph_partition", [True, False]) @pytest.mark.skipif( From 71c6b7260ded8410bbfeaab937179808ed4acbcb Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 6 Nov 2025 10:19:25 +0000 Subject: [PATCH 105/137] Update comments Signed-off-by: ilmarkov --- vllm/compilation/collective_fusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index fa8d705eb51a..5bf93d018bd5 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -466,6 +466,7 @@ def __call__(self, graph: fx.Graph): # Max size of the input tensor per world size per device capability # to use flashinfer one shot fused allreduce +# OneShot max size is at most 64MB / world size (FlashInfer restriction) _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[str, dict[int, float]] = { "9.0": { 2: 32, # 32MB From 6766e4f7da7914d7b1a24e6d760f56e181d5fbaa Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 17:15:46 -0500 Subject: [PATCH 106/137] Update fakify for compile sizes Signed-off-by: ilmarkov --- vllm/compilation/piecewise_backend.py | 9 ++++++++- vllm/config/compilation.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index ad5b49f28550..fe35aaa9e4ae 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -122,14 +122,21 @@ def fakify_args(self, args: list[Any]) -> list[Any]: return fake_example_inputs def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: + is_compile_size = lambda range: range[0] == range[1] if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) # args are real arguments + # fakify for range, real args for concrete size + args = ( + self.fakify_args(args) + if not is_compile_size(range_entry.compile_range) + else args + ) range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, - self.fakify_args(args), + args, self.compilation_config.inductor_compile_config, self.compilation_config, graph_index=self.piecewise_compile_index, diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 740b970669ed..67cd974a13e7 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -947,7 +947,7 @@ def custom_op_log_check(self): def get_compile_ranges(self) -> list[tuple[int, int]]: """Get the compile ranges for the compilation config.""" - split_points = set(self.compile_ranges_split_points) + split_points = sorted(set(self.compile_ranges_split_points)) compile_ranges = [] for i, s in enumerate(split_points): if i == 0: From af87d7a7996dc857933ce38b8be3badbed95a935 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 6 Nov 2025 09:59:37 -0500 Subject: [PATCH 107/137] Linter fix Signed-off-by: ilmarkov --- vllm/config/compilation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 67cd974a13e7..3a3fdd7f295d 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -947,6 +947,8 @@ def custom_op_log_check(self): def get_compile_ranges(self) -> list[tuple[int, int]]: """Get the compile ranges for the compilation config.""" + if self.compile_ranges_split_points is None: + return [] split_points = sorted(set(self.compile_ranges_split_points)) compile_ranges = [] for i, s in enumerate(split_points): From 2785e4d08873cc5456475d6e89c490194712990e Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 7 Nov 2025 09:17:41 +0000 Subject: [PATCH 108/137] Minor updates Signed-off-by: ilmarkov --- tests/compile/test_fusions_e2e.py | 5 +---- vllm/compilation/collective_fusion.py | 16 +++++++--------- vllm/config/compilation.py | 8 +++++--- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 3c54c2adbf48..cf6a3b2de169 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -71,9 +71,6 @@ class ModelBackendTestCase(NamedTuple): attention_fusions=0, allreduce_fusions=65, ), - ] - - MODELS_MOE = [ ModelBackendTestCase( model_name="Qwen/Qwen3-30B-A3B", model_kwargs=dict(max_model_len=1024), @@ -201,7 +198,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: ) ) # Toggle RMSNorm for FP4 models and unquant models - + list(flat_product(MODELS_FP4 + MODELS + MODELS_MOE, CUSTOM_OPS_RMS_NORM)), + + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), ) @pytest.mark.parametrize("inductor_graph_partition", [True, False]) @pytest.mark.skipif( diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 5bf93d018bd5..69d4606d73eb 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -451,13 +451,13 @@ def __call__(self, graph: fx.Graph): # Max size of the input tensor per world size per device capability # to use flashinfer fused allreduce -FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[str, dict[int, float]] = { - "9.0": { +FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = { + 90: { 2: 64, # 64MB 4: 2, # 2MB 8: 0.5, # 0.5MB }, - "10.0": { + 100: { 2: 64, # 64MB 4: 32, # 32MB 8: 1, # 1MB @@ -467,13 +467,13 @@ def __call__(self, graph: fx.Graph): # Max size of the input tensor per world size per device capability # to use flashinfer one shot fused allreduce # OneShot max size is at most 64MB / world size (FlashInfer restriction) -_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[str, dict[int, float]] = { - "9.0": { +_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = { + 90: { 2: 32, # 32MB 4: 2, # 2MB 8: 0.5, # 0.5MB }, - "10.0": { + 100: { 2: 32, # 32MB 4: 4, # 4MB 8: 1, # 1MB @@ -507,9 +507,7 @@ def call_trtllm_fused_allreduce_norm( current_tensor_size = num_tokens * hidden_size * element_size if num_tokens <= max_token_num: - device_capability = ( - current_platform.get_device_capability().as_version_str() - ) + device_capability = current_platform.get_device_capability().to_int() # Get one shot input size limit for the current world size # for the current device capability max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 015b34fbe094..7606d27c04a1 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -118,12 +118,12 @@ class PassConfig: Unspecified will fallback to default values which are compute capability and world size dependent. FI_ALLREDUCE_FUSION_MAX_SIZE_MB = { - "9.0": { + 90: { 2: 64, # 64MB 4: 2, # 2MB 8: 1, # 1MB }, - "10.0": { + 100: { 2: 64, # 64MB 4: 32, # 32MB 8: 1, # 1MB @@ -151,8 +151,10 @@ def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]: from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB from vllm.platforms import current_platform + if not current_platform.is_cuda(): + return {} return FI_ALLREDUCE_FUSION_MAX_SIZE_MB.get( - current_platform.get_device_capability().as_version_str(), {} + current_platform.get_device_capability().to_int(), {} ) def uuid(self): From b4c1b1d66d6ce3288c65c57251d0492f2e9f475b Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 10 Nov 2025 12:31:48 +0000 Subject: [PATCH 109/137] Address the review Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 50 +++++++++++++----------- vllm/compilation/backends.py | 12 +++--- vllm/compilation/collective_fusion.py | 9 +++-- vllm/compilation/compiler_interface.py | 21 +++++----- vllm/compilation/inductor_pass.py | 9 +++-- vllm/compilation/pass_manager.py | 4 +- vllm/compilation/piecewise_backend.py | 27 ++++++------- vllm/compilation/sequence_parallelism.py | 7 ++-- vllm/config/compilation.py | 8 ++-- vllm/config/utils.py | 36 ++++++++++++++++- vllm/config/vllm.py | 6 ++- vllm/v1/worker/gpu_worker.py | 19 ++++++++- 12 files changed, 137 insertions(+), 71 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 03f31df1ece7..564690f18192 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -3,12 +3,11 @@ import torch from torch import fx as fx from torch import nn -from torch.library import Library +import tests.compile.silly_attention # noqa from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.compilation.inductor_pass import ( - CustomGraphPass, InductorPass, get_pass_context, ) @@ -18,11 +17,9 @@ ) from vllm.config.compilation import CompilationConfig, CompilationMode from vllm.config.scheduler import SchedulerConfig +from vllm.config.utils import Range from vllm.forward_context import set_forward_context -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa - BATCH_SIZE = 64 MLP_SIZE = 128 @@ -49,24 +46,34 @@ def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]) model(torch.randn(batch_size, MLP_SIZE).cuda()) -class PostGradPassManagerCheckRanges(CustomGraphPass): - def __init__(self, ranges: list[tuple[int, int]]): +class PostGradPassManagerCheckRanges(InductorPass): + def __init__(self, ranges: list[Range]): self.ranges = ranges + self.num_calls = 0 def __call__(self, graph: fx.Graph): compile_range = get_pass_context().compile_range assert compile_range in self.ranges, ( f"Compile range {compile_range} not in {self.ranges}" ) + self.num_calls += 1 def uuid(self) -> str: state = { - "ranges": self.ranges, + "ranges": [str(range) for range in self.ranges], + "current_compile_range": str(get_pass_context().compile_range), } return InductorPass.hash_dict(state) def test_compile_ranges(): + post_grad_pass_manager = PostGradPassManagerCheckRanges( + [ + Range(start=1, end=8), + Range(start=8, end=32), + Range(start=32, end=8193), + ] + ) vllm_config = VllmConfig( scheduler_config=SchedulerConfig( max_num_batched_tokens=8192, @@ -74,22 +81,21 @@ def test_compile_ranges(): compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, compile_ranges_split_points=[8, 32], + inductor_compile_config={ + "post_grad_custom_post_pass": post_grad_pass_manager + }, ), - inductor_compile_config={ - "post_grad_custom_post_pass": PostGradPassManagerCheckRanges( - [(1, 8), (8, 32), (32, 2049)] - ) - }, ) with set_current_vllm_config(vllm_config): model = TestModel(vllm_config=vllm_config, prefix="").eval().cuda() - batch_sizes = [1, 16, 48] - # A has support_torch_compile - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=1, - num_backend_compilations=3, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ): - run_model(vllm_config, model, batch_sizes) + batch_sizes = [1, 16, 48] + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=3, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, model, batch_sizes) + assert post_grad_pass_manager.num_calls == 3 diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 7a1d851ebe42..0d7ef88c8e6a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -22,6 +22,7 @@ resolve_defined_ops, ) from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig +from vllm.config.utils import Range from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname @@ -83,7 +84,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[tuple[int, int] | None, int, str], Any] = dict() + self.cache: dict[tuple[Range | None, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -92,7 +93,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, compile_range: tuple[int, int] | None = None): + def compile_context(self, compile_range: Range | None = None): """Provide compilation context for the duration of compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. partition rules, pass context).""" @@ -152,7 +153,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Callable | None: if (compile_range, graph_index, self.compiler.name) not in self.cache: return None @@ -187,7 +188,7 @@ def compile( compilation_config: CompilationConfig, graph_index: int = 0, num_graphs: int = 1, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -206,6 +207,7 @@ def compile( # there can be multiple graphs due to piecewise compilation. now = time.time() elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed if compile_range is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " @@ -231,7 +233,7 @@ def compile( if compile_range is None: maybe_key += "dynamic_shape" else: - maybe_key += f"{compile_range[0]}_{compile_range[1]}" + maybe_key += f"{compile_range.start}_{compile_range.end}" maybe_key += f"_subgraph_{graph_index}" with self.compile_context(compile_range): compiled_graph, handle = self.compiler.compile( diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index dbe17f984808..81e881373e45 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,6 +10,7 @@ from torch.distributed._symmetric_memory import enable_symm_mem_for_group from vllm.config import VllmConfig +from vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, @@ -431,7 +432,7 @@ def __init__(self, config: VllmConfig): self.dump_patterns(config, self.patterns) - def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: + def is_applicable_for_range(self, compile_range: Range | None) -> bool: # This pass is applied on top of the sequence parallelism pass. # It inherits the same applicability condition as `SequenceParallelismPass`. # See `SequenceParallelismPass.is_applicable` for more details. @@ -442,7 +443,7 @@ def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool return True tp_size = get_tensor_model_parallel_world_size() return compile_range is not None and ( - compile_range[0] == compile_range[1] and compile_range[1] % tp_size == 0 + compile_range.is_single_size() and compile_range.end % tp_size == 0 ) @VllmInductorPass.time_and_log @@ -1188,10 +1189,10 @@ def register_patterns(self): self.disabled = False - def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: + def is_applicable_for_range(self, compile_range: Range | None) -> bool: if compile_range is None: return False - return compile_range[1] - 1 <= self.max_token_num + return compile_range.end - 1 <= self.max_token_num @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 6124a5428f6c..b95067aba191 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -16,6 +16,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig +from vllm.config.utils import Range from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -63,7 +64,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ @@ -99,7 +100,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Callable: """ Load the compiled function from the handle. @@ -213,7 +214,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -223,8 +224,8 @@ def compile( set_inductor_config(current_config, compile_range) set_functorch_config() - if isinstance(compile_range, tuple): - if compile_range[0] == compile_range[1]: + if compile_range is not None: + if compile_range.is_single_size(): dynamic_shapes = "from_example_inputs" else: dynamic_shapes = "from_graph" @@ -254,7 +255,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -318,7 +319,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -515,7 +516,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -612,7 +613,7 @@ def metrics_context(self) -> contextlib.AbstractContextManager: def set_inductor_config(config, compile_range): - if isinstance(compile_range, tuple) and compile_range[0] == compile_range[1]: + if compile_range is not None and compile_range.is_single_size(): # for a specific batch size, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE @@ -633,7 +634,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 599fa776b6c0..008eba4629a3 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -14,6 +14,7 @@ from torch import fx from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily +from vllm.config.utils import Range from vllm.utils.torch_utils import is_torch_equal_or_newer if is_torch_equal_or_newer("2.6"): @@ -28,8 +29,8 @@ class PassContext: - def __init__(self, compile_range: tuple[int, int] | None): - self.compile_range = compile_range + def __init__(self, compile_range: Range | None): + self.compile_range: Range | None = compile_range def get_pass_context() -> PassContext: @@ -39,7 +40,7 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(compile_range: tuple[int, int] | None): +def pass_context(compile_range: Range | None): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. """ @@ -96,7 +97,7 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable_for_range(self, compile_range: tuple[int, int] | None): + def is_applicable_for_range(self, compile_range: Range | None): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 5984f968da35..4664d0d9aefd 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -127,6 +127,8 @@ def uuid(self): for pass_ in self.passes: state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) - state["compile_range"] = get_pass_context().compile_range + compile_range = get_pass_context().compile_range + if compile_range is not None: + state["compile_range"] = str(compile_range) return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index fe35aaa9e4ae..10844b69c455 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -10,6 +10,7 @@ from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.logger import init_logger logger = init_logger(__name__) @@ -17,7 +18,7 @@ @dataclasses.dataclass class RangeEntry: - compile_range: tuple[int, int] + compile_range: Range compiled: bool = False runnable: Callable = None # type: ignore @@ -61,12 +62,6 @@ def __init__( log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" logger.debug_once(log_string) - self.is_in_range = ( - lambda x, range: range[0] <= x < range[1] - if range[0] < range[1] - else x == range[0] - ) - self.first_run_finished = False self.sym_shape_indices = sym_shape_indices @@ -75,15 +70,15 @@ def __init__( # self.concrete_size_entries: dict[int, RangeEntry] = {} # the entries for ranges that we need to either - self.range_entries: dict[tuple[int, int], RangeEntry] = {} + self.range_entries: dict[Range, RangeEntry] = {} # to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy it - self.to_be_compiled_ranges: set[tuple[int, int]] = set(self.compile_ranges) + self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges) # We only keep compilation management inside this class directly. for size in self.compile_sizes: - range = (size, size) + range = Range(start=size, end=size) self.range_entries[range] = RangeEntry( compile_range=range, ) @@ -122,7 +117,6 @@ def fakify_args(self, args: list[Any]) -> list[Any]: return fake_example_inputs def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: - is_compile_size = lambda range: range[0] == range[1] if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) @@ -131,7 +125,7 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: # fakify for range, real args for concrete size args = ( self.fakify_args(args) - if not is_compile_size(range_entry.compile_range) + if not range_entry.compile_range.is_single_size() else args ) range_entry.runnable = self.vllm_backend.compiler_manager.compile( @@ -158,13 +152,18 @@ def __call__(self, *args) -> Any: return range_entry.runnable(*args) runtime_shape = args[self.sym_shape_indices[0]] + # First we try to find the range entry for the concrete compile size + # If not found, we search for the range entry + # that contains the runtime shape. range_found = False if runtime_shape in self.compile_sizes: - range_entry = self.range_entries[(runtime_shape, runtime_shape)] + range_entry = self.range_entries[ + Range(start=runtime_shape, end=runtime_shape) + ] range_found = True else: for range in self.compile_ranges: - if self.is_in_range(runtime_shape, range): + if range.contains(runtime_shape): range_entry = self.range_entries[range] range_found = True break diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index cf47adb4670a..6a5ee5a0efb7 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -7,6 +7,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -482,7 +483,7 @@ def __init__(self, config: VllmConfig): ).register(self.patterns) self.dump_patterns(config, self.patterns) - def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: + def is_applicable_for_range(self, compile_range: Range | None) -> bool: # When sequence parallelism is enabled, the residual tensor from RMSNorm # needs to be split along the sequence dimension. However, this dimension # is symbolic during piecewise compilation, and splitting symbolic shapes @@ -504,8 +505,8 @@ def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool tp_size = get_tensor_model_parallel_world_size() return ( compile_range is not None - and (compile_range[0] == compile_range[1]) - and (compile_range[1] % tp_size == 0) + and (compile_range.is_single_size()) + and (compile_range.end % tp_size == 0) ) @VllmInductorPass.time_and_log diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 2ae93c59ddfb..298fe4242a83 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -14,7 +14,7 @@ import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass -from vllm.config.utils import config +from vllm.config.utils import Range, config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname @@ -945,7 +945,7 @@ def custom_op_log_check(self): op, ) - def get_compile_ranges(self) -> list[tuple[int, int]]: + def get_compile_ranges(self) -> list[Range]: """Get the compile ranges for the compilation config.""" if self.compile_ranges_split_points is None: return [] @@ -953,7 +953,7 @@ def get_compile_ranges(self) -> list[tuple[int, int]]: compile_ranges = [] for i, s in enumerate(split_points): if i == 0: - compile_ranges.append((1, s)) + compile_ranges.append(Range(start=1, end=s)) else: - compile_ranges.append((split_points[i - 1], s)) + compile_ranges.append(Range(start=split_points[i - 1], end=s)) return compile_ranges diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 7e0878d96bbd..7270caf02740 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -6,7 +6,7 @@ import inspect import textwrap from collections.abc import Iterable -from dataclasses import MISSING, Field, field, fields, is_dataclass, replace +from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace from itertools import pairwise from typing import TYPE_CHECKING, Any, Protocol, TypeVar @@ -176,3 +176,37 @@ def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT: ) processed_overrides[field_name] = value return replace(config, **processed_overrides) + + +@dataclass +class Range: + """ + A range of numbers. + Inclusive of start, exclusive of end. + """ + + start: int + end: int + + def is_single_size(self) -> bool: + return self.start == self.end + + def contains(self, size: int) -> bool: + # Inclusive of start, exclusive of end + if self.is_single_size(): + return size == self.start + return self.start <= size < self.end + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Range): + return False + return self.start == other.start and self.end == other.end + + def __hash__(self) -> int: + return hash((self.start, self.end)) + + def __str__(self) -> str: + return f"(start={self.start}, end={self.end})" + + def __repr__(self) -> str: + return self.__str__() diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 43a3b51b3a0a..a217b3c48f81 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -889,7 +889,11 @@ def _set_compile_ranges(self): # We add 1 because the bounds checks in the compiler are # exclusive and we want to include the max_token_num in the # compile range - computed_compile_ranges_split_points.append(max_token_num + 1) + if ( + max_num_batched_tokens is not None + and max_token_num < max_num_batched_tokens + ): + computed_compile_ranges_split_points.append(max_token_num + 1) if compilation_config.compile_ranges_split_points is not None: for x in compilation_config.compile_ranges_split_points: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index f13ff4e726bd..42f9bdeab97e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -14,7 +14,7 @@ import torch.nn as nn import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, @@ -398,12 +398,27 @@ def compile_or_warm_up_model(self) -> None: # but users still want to compile for better performance, # e.g. for the max-num-batched token size in chunked prefill. warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() - if not self.model_config.enforce_eager: + + if ( + not self.model_config.enforce_eager + or self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + ): warmup_sizes = [ x for x in warmup_sizes if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes ] + compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() + + # For each compile_range, if none of the batch sizes + # in warmup_sizes or cudagraph_capture_sizes are in the range, + # add the start of the range to ensure compilation/warmup. + all_sizes = set(self.vllm_config.compilation_config.cudagraph_capture_sizes) + all_sizes.update(warmup_sizes) + for compile_range in compile_ranges: + if not any(compile_range.contains(x) for x in all_sizes): + warmup_sizes.append(compile_range.start) + # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) From b0a38849c3bc655a2220419eae9c12277899a273 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 10 Nov 2025 17:08:55 +0000 Subject: [PATCH 110/137] Fix SP Signed-off-by: ilmarkov --- vllm/model_executor/layers/fused_moe/layer.py | 97 +++++++++---------- 1 file changed, 48 insertions(+), 49 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 62c4a11092ef..2b8280f941e3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2118,6 +2118,14 @@ def set_eplb_state( self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] self.logical_replica_count = logical_replica_count[moe_layer_idx] + def get_sp_ctx(self): + ctx = get_forward_context() + return ( + ctx.dp_metadata.sp_local_sizes(self.sp_size) + if ctx.dp_metadata + else nullcontext() + ) + def ensure_moe_quant_config_init(self): if self.quant_method.moe_quant_config is None: self.quant_method.moe_quant_config = ( @@ -2336,22 +2344,20 @@ def forward_native( self.quant_method, FusedMoEModularMethod ) - ctx = get_forward_context() - sp_ctx = ( - ctx.dp_metadata.sp_local_sizes(self.sp_size) - if ctx.dp_metadata - else nullcontext() - ) - + sp_ctx = self.get_sp_ctx() with sp_ctx: if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits, self.is_sequence_parallel ) - def reduce_output( - states: torch.Tensor, do_combine: bool = True - ) -> torch.Tensor: + def reduce_output( + states: torch.Tensor, do_combine: bool = True + ) -> torch.Tensor: + # Reinitialize the context manager + # as it was reset in the forward_impl. + sp_ctx = self.get_sp_ctx() + with sp_ctx: if do_naive_dispatch_combine and do_combine: states = get_ep_group().combine(states, self.is_sequence_parallel) @@ -2364,41 +2370,39 @@ def reduce_output( states = self.maybe_all_reduce_tensor_model_parallel(states) return states - if self.shared_experts is None: - if current_platform.is_tpu(): - # TODO: Once the OOM issue for the TPU backend is resolved, we - # will switch to using the moe_forward custom op. - fused_output = self.forward_impl(hidden_states, router_logits) - assert not isinstance(fused_output, tuple) - else: - fused_output = torch.ops.vllm.moe_forward( - hidden_states, router_logits, self.layer_name - ) - if self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(fused_output, tuple) - fused_output, zero_expert_result = fused_output - return (reduce_output(fused_output) + zero_expert_result)[ - ..., :og_hidden_states - ] - else: - return reduce_output(fused_output)[..., :og_hidden_states] + if self.shared_experts is None: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + fused_output = self.forward_impl(hidden_states, router_logits) + assert not isinstance(fused_output, tuple) else: - if current_platform.is_tpu(): - # TODO: Once the OOM issue for the TPU backend is resolved, we - # will switch to using the moe_forward custom op. - shared_output, fused_output = self.forward_impl( - hidden_states, router_logits - ) - else: - shared_output, fused_output = torch.ops.vllm.moe_forward_shared( - hidden_states, router_logits, self.layer_name - ) - return ( - reduce_output(shared_output, do_combine=False)[ - ..., :og_hidden_states - ], - reduce_output(fused_output)[..., :og_hidden_states], + fused_output = torch.ops.vllm.moe_forward( + hidden_states, router_logits, self.layer_name + ) + if self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(fused_output, tuple) + fused_output, zero_expert_result = fused_output + return (reduce_output(fused_output) + zero_expert_result)[ + ..., :og_hidden_states + ] + else: + return reduce_output(fused_output)[..., :og_hidden_states] + else: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + shared_output, fused_output = self.forward_impl( + hidden_states, router_logits ) + else: + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + hidden_states, router_logits, self.layer_name + ) + return ( + reduce_output(shared_output, do_combine=False)[..., :og_hidden_states], + reduce_output(fused_output)[..., :og_hidden_states], + ) def forward_cuda( self, @@ -2625,12 +2629,7 @@ def forward_impl( else: shared_output = None - ctx = get_forward_context() - sp_ctx = ( - ctx.dp_metadata.sp_local_sizes(self.sp_size) - if ctx.dp_metadata - else nullcontext() - ) + sp_ctx = self.get_sp_ctx() with sp_ctx: # Matrix multiply. From 319abd5ee9c50b25a7929ba1e3e6588d44fc9d6d Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 12 Nov 2025 18:25:16 +0000 Subject: [PATCH 111/137] Remove dynamic shape Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 41 ++++++--- vllm/compilation/backends.py | 107 +++++++---------------- vllm/compilation/collective_fusion.py | 10 +-- vllm/compilation/compiler_interface.py | 36 ++++---- vllm/compilation/inductor_pass.py | 8 +- vllm/compilation/pass_manager.py | 7 +- vllm/compilation/piecewise_backend.py | 2 +- vllm/compilation/sequence_parallelism.py | 8 +- vllm/config/compilation.py | 10 +-- vllm/config/utils.py | 6 +- vllm/config/vllm.py | 3 +- 11 files changed, 105 insertions(+), 133 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index bacaa48ae477..b15f90395c6a 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -42,9 +42,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch.inference_mode def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]): with set_forward_context({}, vllm_config=vllm_config): - model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + model(torch.randn(BATCH_SIZE, MLP_SIZE)) for batch_size in batch_sizes: - model(torch.randn(batch_size, MLP_SIZE).cuda()) + model(torch.randn(batch_size, MLP_SIZE)) class PostGradPassManagerCheckRanges(InductorPass): @@ -70,11 +70,14 @@ def uuid(self) -> str: def test_compile_ranges(): post_grad_pass_manager = PostGradPassManagerCheckRanges( [ - Range(start=1, end=8), - Range(start=8, end=32), - Range(start=32, end=8193), + Range(start=1, end=9), + Range(start=16, end=16), + Range(start=9, end=33), + Range(start=64, end=64), + Range(start=33, end=8193), ] ) + torch.set_default_device("cuda") vllm_config = VllmConfig( scheduler_config=SchedulerConfig( max_num_batched_tokens=8192, @@ -82,6 +85,7 @@ def test_compile_ranges(): compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, compile_ranges_split_points=[8, 32], + compile_sizes=[16, 64, 128], inductor_compile_config={ "post_grad_custom_post_pass": post_grad_pass_manager, # Disable inductor cache to get the number of passes correctly @@ -91,14 +95,31 @@ def test_compile_ranges(): ) with set_current_vllm_config(vllm_config): - model = TestModel(vllm_config=vllm_config, prefix="").eval().cuda() - batch_sizes = [1, 4, 16, 24, 48, 64] + model = TestModel(vllm_config=vllm_config, prefix="").eval() + # Number of compilations: 3 for each compile range + 2 compile sizes + batch_sizes = [1, 4, 16, 24, 48, 64, 8192] # A has support_torch_compile with compilation_counter.expect( num_graphs_seen=1, num_piecewise_graphs_seen=1, - num_backend_compilations=3, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_backend_compilations=5, ): run_model(vllm_config, model, batch_sizes) - assert post_grad_pass_manager.num_calls == 3 + assert post_grad_pass_manager.num_calls == 5 + + +def test_compile_config_get_compile_ranges(): + compilation_config = CompilationConfig( + compile_ranges_split_points=[8, 32], + ) + VllmConfig( + scheduler_config=SchedulerConfig( + max_num_batched_tokens=8192, + ), + compilation_config=compilation_config, + ) + assert compilation_config.get_compile_ranges() == [ + Range(start=1, end=9), + Range(start=9, end=33), + Range(start=33, end=8193), + ] diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index efd68a71c7e4..b1fe58d08265 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -84,7 +84,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[Range | None, int, str], Any] = dict() + self.cache: dict[tuple[Range, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -93,7 +93,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, compile_range: Range | None = None): + def compile_context(self, compile_range: Range): """Provide compilation context for the duration of compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. partition rules, pass context).""" @@ -153,7 +153,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: Range | None = None, + compile_range: Range, ) -> Callable | None: if (compile_range, graph_index, self.compiler.name) not in self.cache: return None @@ -161,23 +161,13 @@ def load( compiled_graph = self.compiler.load( handle, graph, example_inputs, graph_index, compile_range ) - if compile_range is None: - logger.debug( - "Directly load the %s-th graph for dynamic compile range" - "from %s via handle %s", - graph_index, - self.compiler.name, - handle, - ) - else: - logger.debug( - "Directly load the %s-th graph for compile range %s" - "from %s via handle %s", - graph_index, - str(compile_range), - self.compiler.name, - handle, - ) + logger.debug( + "Directly load the %s-th graph for compile range %sfrom %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, + ) return compiled_graph def compile( @@ -186,9 +176,9 @@ def compile( example_inputs, additional_inductor_config, compilation_config: CompilationConfig, + compile_range: Range, graph_index: int = 0, num_graphs: int = 1, - compile_range: Range | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -208,19 +198,12 @@ def compile( now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if compile_range is None: - logger.info( - "Directly load the compiled graph(s) for dynamic shape " - "from the cache, took %.3f s", - elapsed, - ) - else: - logger.info( - "Directly load the compiled graph(s) for compile range %s " - "from the cache, took %.3f s", - str(compile_range), - elapsed, - ) + logger.info( + "Directly load the compiled graph(s) for compile range %s " + "from the cache, took %.3f s", + str(compile_range), + elapsed, + ) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -230,10 +213,7 @@ def compile( maybe_key = None else: maybe_key = "artifact_compile_range_" - if compile_range is None: - maybe_key += "dynamic_shape" - else: - maybe_key += f"{compile_range.start}_{compile_range.end}" + maybe_key += f"{compile_range.start}_{compile_range.end}" maybe_key += f"_subgraph_{graph_index}" with self.compile_context(compile_range): compiled_graph, handle = self.compiler.compile( @@ -253,50 +233,29 @@ def compile( self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph - if compile_range is None: - logger.info_once( - "Cache the graph for dynamic shape for later use", scope="local" - ) - else: - logger.info_once( - "Cache the graph of compile range %s for later use", - str(compile_range), - ) - if compile_range is None: - logger.debug( - "Store the %s-th graph for dynamic compile range" - "from %s via handle %s", - graph_index, - self.compiler.name, - handle, - ) - else: - logger.debug( - "Store the %s-th graph for compile range%s from %s via handle %s", - graph_index, + logger.info_once( + "Cache the graph of compile range %s for later use", str(compile_range), - self.compiler.name, - handle, ) + logger.debug( + "Store the %s-th graph for compile range%s from %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, + ) # after compiling the last graph, record the end time if graph_index == num_graphs - 1: now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if compile_range is None: - logger.info_once( - "Compiling a graph for dynamic compile range takes %.2f s", - elapsed, - scope="local", - ) - else: - logger.info_once( - "Compiling a graph for compile range %s takes %.2f s", - str(compile_range), - elapsed, - scope="local", - ) + logger.info_once( + "Compiling a graph for compile range %s takes %.2f s", + str(compile_range), + elapsed, + scope="local", + ) return compiled_graph diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 32d1f1531f4c..bef8925661cd 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -432,7 +432,7 @@ def __init__(self, config: VllmConfig): self.dump_patterns(config, self.patterns) - def is_applicable_for_range(self, compile_range: Range | None) -> bool: + def is_applicable_for_range(self, compile_range: Range) -> bool: # This pass is applied on top of the sequence parallelism pass. # It inherits the same applicability condition as `SequenceParallelismPass`. # See `SequenceParallelismPass.is_applicable` for more details. @@ -442,9 +442,7 @@ def is_applicable_for_range(self, compile_range: Range | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return compile_range is not None and ( - compile_range.is_single_size() and compile_range.end % tp_size == 0 - ) + return compile_range.is_single_size() and compile_range.end % tp_size == 0 @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): @@ -1189,9 +1187,7 @@ def register_patterns(self): self.disabled = False - def is_applicable_for_range(self, compile_range: Range | None) -> bool: - if compile_range is None: - return False + def is_applicable_for_range(self, compile_range: Range) -> bool: return compile_range.end - 1 <= self.max_token_num @VllmInductorPass.time_and_log diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index b95067aba191..3bafba2e1642 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -64,16 +64,15 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: Range | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, - with a range. If the `compile_range` is None, it means - the `example_inputs` have a dynamic shape. Otherwise, the - `compile_range` specifies the range of the inputs, - it could be concrete size, e.g. (4, 4). - Right now we only support one variable range of shapes for all inputs, + with a range. The `compile_range` specifies the range of the inputs, + it could be concrete size (if compile_sizes is provided), e.g. [4, 4) + or a range [4, 5). + Right now we only support one variable in ranges for all inputs, which is the batchsize (number of tokens) during inference. Dynamo will make sure `graph(*example_inputs)` is valid. @@ -100,7 +99,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: Range | None = None, + compile_range: Range, ) -> Callable: """ Load the compiled function from the handle. @@ -214,7 +213,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: Range | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -224,13 +223,10 @@ def compile( set_inductor_config(current_config, compile_range) set_functorch_config() - if compile_range is not None: - if compile_range.is_single_size(): - dynamic_shapes = "from_example_inputs" - else: - dynamic_shapes = "from_graph" + if compile_range.is_single_size(): + dynamic_shapes = "from_example_inputs" else: - dynamic_shapes = "from_tracing_context" + dynamic_shapes = "from_graph" from torch._inductor import standalone_compile @@ -255,7 +251,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: Range | None = None, + compile_range: Range, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -319,7 +315,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: Range | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -516,7 +512,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: Range | None = None, + compile_range: Range, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -612,8 +608,8 @@ def metrics_context(self) -> contextlib.AbstractContextManager: return contextlib.nullcontext() -def set_inductor_config(config, compile_range): - if compile_range is not None and compile_range.is_single_size(): +def set_inductor_config(config, compile_range: Range): + if compile_range.is_single_size(): # for a specific batch size, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE @@ -634,7 +630,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: Range | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 008eba4629a3..8159b817f637 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -29,8 +29,8 @@ class PassContext: - def __init__(self, compile_range: Range | None): - self.compile_range: Range | None = compile_range + def __init__(self, compile_range: Range): + self.compile_range: Range = compile_range def get_pass_context() -> PassContext: @@ -40,7 +40,7 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(compile_range: Range | None): +def pass_context(compile_range: Range): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. """ @@ -97,7 +97,7 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable_for_range(self, compile_range: Range | None): + def is_applicable_for_range(self, compile_range: Range): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 820fa9b007e3..399c998d87f8 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -128,9 +128,8 @@ def uuid(self): state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) compile_range = get_pass_context().compile_range - if compile_range is not None: - # Include the compile range in the uuid to ensure that inductor - # recompiles the graph for the new dynamic compile range. - state["compile_range"] = str(compile_range) + # Include the compile range in the uuid to ensure that inductor + # recompiles the graph for the new dynamic compile range. + state["compile_range"] = str(compile_range) return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 8f34aa818a80..b59cc50f70bc 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -133,9 +133,9 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: args, self.compilation_config.inductor_compile_config, self.compilation_config, + compile_range=range_entry.compile_range, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - compile_range=range_entry.compile_range, ) # finished compilations for all required shapes diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 6a5ee5a0efb7..84484756e7ef 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -483,7 +483,7 @@ def __init__(self, config: VllmConfig): ).register(self.patterns) self.dump_patterns(config, self.patterns) - def is_applicable_for_range(self, compile_range: Range | None) -> bool: + def is_applicable_for_range(self, compile_range: Range) -> bool: # When sequence parallelism is enabled, the residual tensor from RMSNorm # needs to be split along the sequence dimension. However, this dimension # is symbolic during piecewise compilation, and splitting symbolic shapes @@ -503,11 +503,7 @@ def is_applicable_for_range(self, compile_range: Range | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return ( - compile_range is not None - and (compile_range.is_single_size()) - and (compile_range.end % tp_size == 0) - ) + return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 36bbd2b9abff..85118544117d 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -349,11 +349,11 @@ class CompilationConfig: compile_ranges_split_points: list[int] | None = None """Split points that represent compile ranges for inductor. The compile ranges are - [1, split_points[0]), - [split_points[0], split_points[1]), ..., - [split_points[-1], max_num_batched_tokens + 1). - Compile sizes are also used single element ranges: - [compile_sizes[i], compile_sizes[i] + 1). + [1, split_points[0] + 1), + [split_points[0] + 1, split_points[1] + 1), ..., + [split_points[-1] + 1, max_num_batched_tokens + 1). + Compile sizes are also used single element ranges, + the range is represented as [compile_sizes[i], compile_sizes[i] + 1). """ inductor_compile_config: dict = field(default_factory=dict) diff --git a/vllm/config/utils.py b/vllm/config/utils.py index ea97ddf125f7..20304696ffcc 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -206,7 +206,11 @@ def __hash__(self) -> int: return hash((self.start, self.end)) def __str__(self) -> str: - return f"(start={self.start}, end={self.end})" + return ( + f"[{self.start}, {self.end + 1})" + if self.is_single_size() + else f"[{self.start}, {self.end})" + ) def __repr__(self) -> str: return self.__str__() diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 4557e59a5cf8..2d71bec7c517 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -965,12 +965,13 @@ def _set_compile_ranges(self): for x in compilation_config.compile_ranges_split_points: assert isinstance(x, int) assert x > 0, f"Invalid compile range split point: {x}" + # Split points need to be inclusive of the end so we add 1. if ( max_num_batched_tokens is not None and x < max_num_batched_tokens and x > 1 ): - computed_compile_ranges_split_points.append(x) + computed_compile_ranges_split_points.append(x + 1) compilation_config.compile_ranges_split_points = sorted( computed_compile_ranges_split_points ) # type: ignore From d168de0c16e3ce0894c4ea11c54abe729b4bd6e7 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 13 Nov 2025 10:02:55 +0000 Subject: [PATCH 112/137] Make ranges inclusive-inclusive Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 12 ++++++------ vllm/compilation/collective_fusion.py | 2 +- vllm/compilation/piecewise_backend.py | 9 +++++---- vllm/config/compilation.py | 15 ++++++++++----- vllm/config/utils.py | 12 ++++-------- vllm/config/vllm.py | 13 +++---------- vllm/v1/worker/gpu_worker.py | 4 ++-- 7 files changed, 31 insertions(+), 36 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index b15f90395c6a..1467d6d5b1ba 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -70,11 +70,11 @@ def uuid(self) -> str: def test_compile_ranges(): post_grad_pass_manager = PostGradPassManagerCheckRanges( [ - Range(start=1, end=9), + Range(start=1, end=8), Range(start=16, end=16), - Range(start=9, end=33), + Range(start=9, end=32), Range(start=64, end=64), - Range(start=33, end=8193), + Range(start=33, end=8192), ] ) torch.set_default_device("cuda") @@ -119,7 +119,7 @@ def test_compile_config_get_compile_ranges(): compilation_config=compilation_config, ) assert compilation_config.get_compile_ranges() == [ - Range(start=1, end=9), - Range(start=9, end=33), - Range(start=33, end=8193), + Range(start=1, end=8), + Range(start=9, end=32), + Range(start=33, end=8192), ] diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index bef8925661cd..2717738dd7c2 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1188,7 +1188,7 @@ def register_patterns(self): self.disabled = False def is_applicable_for_range(self, compile_range: Range) -> bool: - return compile_range.end - 1 <= self.max_token_num + return compile_range.end <= self.max_token_num @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index b59cc50f70bc..d53fa62bdc11 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -79,10 +79,11 @@ def __init__( # We only keep compilation management inside this class directly. for size in self.compile_sizes: range = Range(start=size, end=size) - self.range_entries[range] = RangeEntry( - compile_range=range, - ) - self.to_be_compiled_ranges.add(range) + if range not in self.compile_ranges: + self.range_entries[range] = RangeEntry( + compile_range=range, + ) + self.to_be_compiled_ranges.add(range) for range in self.compile_ranges: self.range_entries[range] = RangeEntry( diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 85118544117d..42b459b6626a 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -349,11 +349,16 @@ class CompilationConfig: compile_ranges_split_points: list[int] | None = None """Split points that represent compile ranges for inductor. The compile ranges are - [1, split_points[0] + 1), - [split_points[0] + 1, split_points[1] + 1), ..., - [split_points[-1] + 1, max_num_batched_tokens + 1). + [1, split_points[0]], + [split_points[0] + 1, split_points[1]], ..., + [split_points[-1] + 1, max_num_batched_tokens]. Compile sizes are also used single element ranges, - the range is represented as [compile_sizes[i], compile_sizes[i] + 1). + the range is represented as [compile_sizes[i], compile_sizes[i]]. + + If a range overlaps with the compile size, graph for compile size + will be prioritized, i.e. if we have a range [1, 8] and a compile size 4, + graph for compile size 4 will be compiled and used instead of the graph + for range [1, 8]. """ inductor_compile_config: dict = field(default_factory=dict) @@ -964,5 +969,5 @@ def get_compile_ranges(self) -> list[Range]: if i == 0: compile_ranges.append(Range(start=1, end=s)) else: - compile_ranges.append(Range(start=split_points[i - 1], end=s)) + compile_ranges.append(Range(start=split_points[i - 1] + 1, end=s)) return compile_ranges diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 20304696ffcc..c4e9a5ef6ff5 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -182,7 +182,7 @@ def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT: class Range: """ A range of numbers. - Inclusive of start, exclusive of end. + Inclusive of start, inclusive of end. """ start: int @@ -192,10 +192,10 @@ def is_single_size(self) -> bool: return self.start == self.end def __contains__(self, size: int) -> bool: - # Inclusive of start, exclusive of end + # Inclusive of start, inclusive of end if self.is_single_size(): return size == self.start - return self.start <= size < self.end + return self.start <= size <= self.end def __eq__(self, other: object) -> bool: if not isinstance(other, Range): @@ -206,11 +206,7 @@ def __hash__(self) -> int: return hash((self.start, self.end)) def __str__(self) -> str: - return ( - f"[{self.start}, {self.end + 1})" - if self.is_single_size() - else f"[{self.start}, {self.end})" - ) + return f"[{self.start}, {self.end}]" def __repr__(self) -> str: return self.__str__() diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 2d71bec7c517..6a88f96b1eea 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -938,10 +938,7 @@ def _set_compile_ranges(self): # The upper bound of the compile ranges is the max_num_batched_tokens max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens if max_num_batched_tokens is not None: - # We add 1 because the bounds checks in the compiler are exclusive - # and we want to include the max_num_batched_tokens - # in the compile range - computed_compile_ranges_split_points.append(max_num_batched_tokens + 1) + computed_compile_ranges_split_points.append(max_num_batched_tokens) # Add the compile ranges for flashinfer if compilation_config.pass_config.enable_fi_allreduce_fusion: @@ -952,26 +949,22 @@ def _set_compile_ranges(self): self.model_config.get_hidden_size() * self.model_config.dtype.itemsize ) - # We add 1 because the bounds checks in the compiler are - # exclusive and we want to include the max_token_num in the - # compile range if ( max_num_batched_tokens is not None and max_token_num < max_num_batched_tokens ): - computed_compile_ranges_split_points.append(max_token_num + 1) + computed_compile_ranges_split_points.append(max_token_num) if compilation_config.compile_ranges_split_points is not None: for x in compilation_config.compile_ranges_split_points: assert isinstance(x, int) assert x > 0, f"Invalid compile range split point: {x}" - # Split points need to be inclusive of the end so we add 1. if ( max_num_batched_tokens is not None and x < max_num_batched_tokens and x > 1 ): - computed_compile_ranges_split_points.append(x + 1) + computed_compile_ranges_split_points.append(x) compilation_config.compile_ranges_split_points = sorted( computed_compile_ranges_split_points ) # type: ignore diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0bc9aa5ee863..04d8656e03b3 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -415,12 +415,12 @@ def compile_or_warm_up_model(self) -> None: # For each compile_range, if none of the batch sizes # in warmup_sizes or cudagraph_capture_sizes are in the range, - # add the start of the range to ensure compilation/warmup. + # add the end of the range to ensure compilation/warmup. all_sizes = set(self.vllm_config.compilation_config.cudagraph_capture_sizes) all_sizes.update(warmup_sizes) for compile_range in compile_ranges: if not any(x in compile_range for x in all_sizes): - warmup_sizes.append(compile_range.end - 1) + warmup_sizes.append(compile_range.end) # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): From 6c059198d0c1b98904807a12f903efed3e03f609 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 19 Nov 2025 14:34:38 +0000 Subject: [PATCH 113/137] Add test for inductor cache hits Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 50 +++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 1467d6d5b1ba..9ee23b6627da 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -98,7 +98,7 @@ def test_compile_ranges(): model = TestModel(vllm_config=vllm_config, prefix="").eval() # Number of compilations: 3 for each compile range + 2 compile sizes batch_sizes = [1, 4, 16, 24, 48, 64, 8192] - # A has support_torch_compile + with compilation_counter.expect( num_graphs_seen=1, num_piecewise_graphs_seen=1, @@ -123,3 +123,51 @@ def test_compile_config_get_compile_ranges(): Range(start=9, end=32), Range(start=33, end=8192), ] + + +def test_inductor_cache_compile_ranges(monkeypatch): + # To force multiple compilations, we disable the compile cache + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + post_grad_pass_manager = PostGradPassManagerCheckRanges( + ranges=[ + Range(start=1, end=8), + Range(start=9, end=8192), + ] + ) + scheduler_config = SchedulerConfig( + max_num_batched_tokens=8192, + ) + torch.set_default_device("cuda") + + def create_vllm_config(): + return VllmConfig( + scheduler_config=scheduler_config, + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + compile_ranges_split_points=[8], + inductor_compile_config={ + "post_grad_custom_post_pass": post_grad_pass_manager, + # Leave inductor cache enabled to verify that the cache is used + "force_disable_caches": False, + }, + ), + ) + + vllm_config_1 = create_vllm_config() + with set_current_vllm_config(vllm_config_1): + model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval() + batch_sizes = [1, 16] + run_model(vllm_config_1, model1, batch_sizes) + # Could be 0 or 2, depending on the cache + num_call_initially = post_grad_pass_manager.num_calls + assert num_call_initially in [0, 2] + + # Create a new vllm config with the new pass context + vllm_config_2 = create_vllm_config() + with set_current_vllm_config(vllm_config_2): + model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval() + batch_sizes = [1, 16] + run_model(vllm_config_2, model2, batch_sizes) + # Check that cache is used + assert post_grad_pass_manager.num_calls == num_call_initially From 3f72483d8f90bdf027a8451986c8b07ffa19df8a Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 19 Nov 2025 16:01:36 +0000 Subject: [PATCH 114/137] Address comments Signed-off-by: ilmarkov --- vllm/compilation/piecewise_backend.py | 15 +-------------- vllm/config/vllm.py | 4 ++-- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index d53fa62bdc11..f1f86b885833 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -62,13 +62,8 @@ def __init__( log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" logger.debug_once(log_string) - self.first_run_finished = False - self.sym_shape_indices = sym_shape_indices - # the entries for different shapes that we need to compile - # self.concrete_size_entries: dict[int, RangeEntry] = {} - # the entries for ranges that we need to either self.range_entries: dict[Range, RangeEntry] = {} @@ -107,6 +102,7 @@ def fakify_args(self, args: list[Any]) -> list[Any]: # it will force specializing the whole shape # torch.compile probably should not accept # non fake tensors as example inputs! + # See issue https://github.com/vllm-project/vllm/issues/27899 fake_example_inputs = [] for node in self.graph.graph.nodes: # All place holders come first @@ -143,14 +139,6 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: self.check_for_ending_compilation() def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - - # Role of the general graph is taken by the last range graph - range_entry = self.range_entries[self.compile_ranges[-1]] - self._maybe_compile_for_range_entry(range_entry, args) - return range_entry.runnable(*args) runtime_shape = args[self.sym_shape_indices[0]] # First we try to find the range entry for the concrete compile size @@ -174,5 +162,4 @@ def __call__(self, *args) -> Any: ) self._maybe_compile_for_range_entry(range_entry, args) - return range_entry.runnable(*args) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 26abd3fa4a4e..5e7340884674 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -525,7 +525,7 @@ def __post_init__(self): "--kv-sharing-fast-prefill requires changes on model side for " "correctness and to realize prefill savings. " ) - + # TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands self._set_compile_ranges() disable_chunked_prefill_reasons: list[str] = [] @@ -987,7 +987,7 @@ def _set_compile_ranges(self): computed_compile_ranges_split_points.append(x) compilation_config.compile_ranges_split_points = sorted( computed_compile_ranges_split_points - ) # type: ignore + ) def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config From 9b00ebc3d8404a61ead4a187b93a77608e6a6e9d Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 20 Nov 2025 10:16:26 +0000 Subject: [PATCH 115/137] Address comments Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 17 +++++++++-------- vllm/compilation/piecewise_backend.py | 27 ++++++++++++--------------- vllm/config/utils.py | 2 -- 3 files changed, 21 insertions(+), 25 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 9ee23b6627da..fee88f91a75c 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + import torch from torch import fx as fx from torch import nn @@ -47,7 +49,7 @@ def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]) model(torch.randn(batch_size, MLP_SIZE)) -class PostGradPassManagerCheckRanges(InductorPass): +class PostGradRangeChecker(InductorPass): def __init__(self, ranges: list[Range]): self.ranges = ranges self.num_calls = 0 @@ -60,15 +62,12 @@ def __call__(self, graph: fx.Graph): self.num_calls += 1 def uuid(self) -> str: - state = { - "ranges": [str(range) for range in self.ranges], - "current_compile_range": str(get_pass_context().compile_range), - } + state: dict[str, Any] = {} return InductorPass.hash_dict(state) def test_compile_ranges(): - post_grad_pass_manager = PostGradPassManagerCheckRanges( + post_grad_pass_manager = PostGradRangeChecker( [ Range(start=1, end=8), Range(start=16, end=16), @@ -129,7 +128,7 @@ def test_inductor_cache_compile_ranges(monkeypatch): # To force multiple compilations, we disable the compile cache monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - post_grad_pass_manager = PostGradPassManagerCheckRanges( + post_grad_pass_manager = PostGradRangeChecker( ranges=[ Range(start=1, end=8), Range(start=9, end=8192), @@ -160,6 +159,7 @@ def create_vllm_config(): batch_sizes = [1, 16] run_model(vllm_config_1, model1, batch_sizes) # Could be 0 or 2, depending on the cache + # pytorch issue https://github.com/pytorch/pytorch/issues/168239 num_call_initially = post_grad_pass_manager.num_calls assert num_call_initially in [0, 2] @@ -169,5 +169,6 @@ def create_vllm_config(): model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval() batch_sizes = [1, 16] run_model(vllm_config_2, model2, batch_sizes) - # Check that cache is used + # Check that cache is used, so the number of calls + # should be the same as initially assert post_grad_pass_manager.num_calls == num_call_initially diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index f1f86b885833..17e579dab17b 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -92,7 +92,7 @@ def check_for_ending_compilation(self): self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def fakify_args(self, args: list[Any]) -> list[Any]: + def _fakify_args(self, args: list[Any]) -> list[Any]: # We need to pass fake example_inputs, otherwise torch.compile # will fakify the example_inputs potentially causing some non dynamic # dimension to be be duck shaped to other existing shapes that have hints @@ -121,7 +121,7 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: # args are real arguments # fakify for range, real args for concrete size args = ( - self.fakify_args(args) + self._fakify_args(args) if not range_entry.compile_range.is_single_size() else args ) @@ -135,28 +135,25 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: num_graphs=self.total_piecewise_compiles, ) - # finished compilations for all required shapes self.check_for_ending_compilation() - def __call__(self, *args) -> Any: - runtime_shape = args[self.sym_shape_indices[0]] - + def _find_range_for_shape(self, runtime_shape: int) -> Range | None: # First we try to find the range entry for the concrete compile size # If not found, we search for the range entry # that contains the runtime shape. - range_found = False if runtime_shape in self.compile_sizes: - range_entry = self.range_entries[ - Range(start=runtime_shape, end=runtime_shape) - ] - range_found = True + return self.range_entries[Range(start=runtime_shape, end=runtime_shape)] else: for range in self.compile_ranges: if runtime_shape in range: - range_entry = self.range_entries[range] - range_found = True - break - assert range_found, ( + return self.range_entries[range] + return None + + def __call__(self, *args) -> Any: + runtime_shape = args[self.sym_shape_indices[0]] + range_entry = self._find_range_for_shape(runtime_shape) + + assert range_entry is not None, ( f"Shape out of considered range: {runtime_shape} " "[1, max_num_batched_tokens]" ) diff --git a/vllm/config/utils.py b/vllm/config/utils.py index de979288f264..3f9aa4261b38 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -310,8 +310,6 @@ def is_single_size(self) -> bool: def __contains__(self, size: int) -> bool: # Inclusive of start, inclusive of end - if self.is_single_size(): - return size == self.start return self.start <= size <= self.end def __eq__(self, other: object) -> bool: From 8a40ac6ce891f4545b19fe4ba0592ff8667b4117 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 20 Nov 2025 20:37:29 +0000 Subject: [PATCH 116/137] Update test Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 111 +++++++++++++-------------- 1 file changed, 54 insertions(+), 57 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index fee88f91a75c..6938fb2d1d02 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -5,6 +5,7 @@ import torch from torch import fx as fx from torch import nn +from torch._inductor.utils import fresh_cache # This import automatically registers `torch.ops.silly.attention` import tests.compile.silly_attention # noqa @@ -67,44 +68,43 @@ def uuid(self) -> str: def test_compile_ranges(): - post_grad_pass_manager = PostGradRangeChecker( - [ - Range(start=1, end=8), - Range(start=16, end=16), - Range(start=9, end=32), - Range(start=64, end=64), - Range(start=33, end=8192), - ] - ) - torch.set_default_device("cuda") - vllm_config = VllmConfig( - scheduler_config=SchedulerConfig( - max_num_batched_tokens=8192, - ), - compilation_config=CompilationConfig( - mode=CompilationMode.VLLM_COMPILE, - compile_ranges_split_points=[8, 32], - compile_sizes=[16, 64, 128], - inductor_compile_config={ - "post_grad_custom_post_pass": post_grad_pass_manager, - # Disable inductor cache to get the number of passes correctly - "force_disable_caches": True, - }, - ), - ) + with fresh_cache(): + post_grad_pass_manager = PostGradRangeChecker( + [ + Range(start=1, end=8), + Range(start=16, end=16), + Range(start=9, end=32), + Range(start=64, end=64), + Range(start=33, end=8192), + ] + ) + torch.set_default_device("cuda") + vllm_config = VllmConfig( + scheduler_config=SchedulerConfig( + max_num_batched_tokens=8192, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + compile_ranges_split_points=[8, 32], + compile_sizes=[16, 64, 128], + inductor_compile_config={ + "post_grad_custom_post_pass": post_grad_pass_manager, + }, + ), + ) - with set_current_vllm_config(vllm_config): - model = TestModel(vllm_config=vllm_config, prefix="").eval() - # Number of compilations: 3 for each compile range + 2 compile sizes - batch_sizes = [1, 4, 16, 24, 48, 64, 8192] + with set_current_vllm_config(vllm_config): + model = TestModel(vllm_config=vllm_config, prefix="").eval() + # Number of compilations: 3 for each compile range + 2 compile sizes + batch_sizes = [1, 4, 16, 24, 48, 64, 8192] - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=1, - num_backend_compilations=5, - ): - run_model(vllm_config, model, batch_sizes) - assert post_grad_pass_manager.num_calls == 5 + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=5, + ): + run_model(vllm_config, model, batch_sizes) + assert post_grad_pass_manager.num_calls == 5 def test_compile_config_get_compile_ranges(): @@ -147,28 +147,25 @@ def create_vllm_config(): compile_ranges_split_points=[8], inductor_compile_config={ "post_grad_custom_post_pass": post_grad_pass_manager, - # Leave inductor cache enabled to verify that the cache is used - "force_disable_caches": False, }, ), ) - vllm_config_1 = create_vllm_config() - with set_current_vllm_config(vllm_config_1): - model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval() - batch_sizes = [1, 16] - run_model(vllm_config_1, model1, batch_sizes) - # Could be 0 or 2, depending on the cache - # pytorch issue https://github.com/pytorch/pytorch/issues/168239 - num_call_initially = post_grad_pass_manager.num_calls - assert num_call_initially in [0, 2] - - # Create a new vllm config with the new pass context - vllm_config_2 = create_vllm_config() - with set_current_vllm_config(vllm_config_2): - model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval() - batch_sizes = [1, 16] - run_model(vllm_config_2, model2, batch_sizes) - # Check that cache is used, so the number of calls - # should be the same as initially - assert post_grad_pass_manager.num_calls == num_call_initially + with fresh_cache(): + vllm_config_1 = create_vllm_config() + with set_current_vllm_config(vllm_config_1): + model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval() + batch_sizes = [1, 16] + run_model(vllm_config_1, model1, batch_sizes) + assert post_grad_pass_manager.num_calls == 2 + + post_grad_pass_manager.num_calls = 0 + # Create a new vllm config with the new pass context + vllm_config_2 = create_vllm_config() + with set_current_vllm_config(vllm_config_2): + model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval() + batch_sizes = [4, 32] + run_model(vllm_config_2, model2, batch_sizes) + # Check that cache is used, so the number of calls + # should be 0 + assert post_grad_pass_manager.num_calls == 0 From ef0568221ae1dcf42abb1c9ee592e6d599935d9d Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 20 Nov 2025 21:11:23 +0000 Subject: [PATCH 117/137] Address comments Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 112 +++++++++++++-------------- tests/utils.py | 23 +++++- 2 files changed, 78 insertions(+), 57 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 6938fb2d1d02..a1fdff570bcd 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -5,10 +5,10 @@ import torch from torch import fx as fx from torch import nn -from torch._inductor.utils import fresh_cache # This import automatically registers `torch.ops.silly.attention` import tests.compile.silly_attention # noqa +from tests.utils import use_fresh_compile_cache from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.compilation.inductor_pass import ( @@ -67,44 +67,44 @@ def uuid(self) -> str: return InductorPass.hash_dict(state) +@use_fresh_compile_cache def test_compile_ranges(): - with fresh_cache(): - post_grad_pass_manager = PostGradRangeChecker( - [ - Range(start=1, end=8), - Range(start=16, end=16), - Range(start=9, end=32), - Range(start=64, end=64), - Range(start=33, end=8192), - ] - ) - torch.set_default_device("cuda") - vllm_config = VllmConfig( - scheduler_config=SchedulerConfig( - max_num_batched_tokens=8192, - ), - compilation_config=CompilationConfig( - mode=CompilationMode.VLLM_COMPILE, - compile_ranges_split_points=[8, 32], - compile_sizes=[16, 64, 128], - inductor_compile_config={ - "post_grad_custom_post_pass": post_grad_pass_manager, - }, - ), - ) + post_grad_range_checker = PostGradRangeChecker( + [ + Range(start=1, end=8), + Range(start=16, end=16), + Range(start=9, end=32), + Range(start=64, end=64), + Range(start=33, end=8192), + ] + ) + torch.set_default_device("cuda") + vllm_config = VllmConfig( + scheduler_config=SchedulerConfig( + max_num_batched_tokens=8192, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + compile_ranges_split_points=[8, 32], + compile_sizes=[16, 64, 128], + inductor_compile_config={ + "post_grad_custom_post_pass": post_grad_range_checker, + }, + ), + ) - with set_current_vllm_config(vllm_config): - model = TestModel(vllm_config=vllm_config, prefix="").eval() - # Number of compilations: 3 for each compile range + 2 compile sizes - batch_sizes = [1, 4, 16, 24, 48, 64, 8192] + with set_current_vllm_config(vllm_config): + model = TestModel(vllm_config=vllm_config, prefix="").eval() + # Number of compilations: 3 for each compile range + 2 compile sizes + batch_sizes = [1, 4, 16, 24, 48, 64, 8192] - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=1, - num_backend_compilations=5, - ): - run_model(vllm_config, model, batch_sizes) - assert post_grad_pass_manager.num_calls == 5 + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=5, + ): + run_model(vllm_config, model, batch_sizes) + assert post_grad_range_checker.num_calls == 5 def test_compile_config_get_compile_ranges(): @@ -124,11 +124,12 @@ def test_compile_config_get_compile_ranges(): ] +@use_fresh_compile_cache def test_inductor_cache_compile_ranges(monkeypatch): # To force multiple compilations, we disable the compile cache monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - post_grad_pass_manager = PostGradRangeChecker( + post_grad_range_checker = PostGradRangeChecker( ranges=[ Range(start=1, end=8), Range(start=9, end=8192), @@ -146,26 +147,25 @@ def create_vllm_config(): mode=CompilationMode.VLLM_COMPILE, compile_ranges_split_points=[8], inductor_compile_config={ - "post_grad_custom_post_pass": post_grad_pass_manager, + "post_grad_custom_post_pass": post_grad_range_checker, }, ), ) - with fresh_cache(): - vllm_config_1 = create_vllm_config() - with set_current_vllm_config(vllm_config_1): - model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval() - batch_sizes = [1, 16] - run_model(vllm_config_1, model1, batch_sizes) - assert post_grad_pass_manager.num_calls == 2 - - post_grad_pass_manager.num_calls = 0 - # Create a new vllm config with the new pass context - vllm_config_2 = create_vllm_config() - with set_current_vllm_config(vllm_config_2): - model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval() - batch_sizes = [4, 32] - run_model(vllm_config_2, model2, batch_sizes) - # Check that cache is used, so the number of calls - # should be 0 - assert post_grad_pass_manager.num_calls == 0 + vllm_config_1 = create_vllm_config() + with set_current_vllm_config(vllm_config_1): + model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval() + batch_sizes = [1, 16] + run_model(vllm_config_1, model1, batch_sizes) + assert post_grad_range_checker.num_calls == 2 + + post_grad_range_checker.num_calls = 0 + # Create a new vllm config with the new pass context + vllm_config_2 = create_vllm_config() + with set_current_vllm_config(vllm_config_2): + model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval() + batch_sizes = [4, 32] + run_model(vllm_config_2, model2, batch_sizes) + # Check that cache is used, so the number of calls + # should be 0 + assert post_grad_range_checker.num_calls == 0 diff --git a/tests/utils.py b/tests/utils.py index c31a2aeeb9c8..82bb4dc100f9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,7 +17,7 @@ import time import warnings from collections.abc import Callable, Iterable -from contextlib import ExitStack, contextmanager, suppress +from contextlib import ExitStack, contextmanager, nullcontext, suppress from multiprocessing import Process from pathlib import Path from typing import Any, Literal @@ -50,6 +50,12 @@ from vllm.utils.network_utils import get_open_port from vllm.utils.torch_utils import cuda_device_count_stateless +try: + from torch._inductor.utils import fresh_cache +except ImportError: + fresh_cache = nullcontext() + + if current_platform.is_rocm(): from amdsmi import ( amdsmi_get_gpu_vram_usage, @@ -1117,6 +1123,21 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: return wrapper +def use_fresh_compile_cache(f: Callable[_P, None]) -> Callable[_P, None]: + """ + Decorator to use a fresh inductor cache for the test. + This is useful to ensure that the test is not affected by the + previous test calls. + """ + + @functools.wraps(f) + def wrapper(*args, **kwargs): + with fresh_cache(): + return f(*args, **kwargs) + + return wrapper + + async def completions_with_server_args( prompts: list[str], model_name: str, From ee89388dfb46a333ab6cc79c5f9e7d9cbae4ac85 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 21 Nov 2025 14:17:07 +0000 Subject: [PATCH 118/137] Update test utils Signed-off-by: ilmarkov --- tests/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/utils.py b/tests/utils.py index 82bb4dc100f9..afda4685b44e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -52,7 +52,10 @@ try: from torch._inductor.utils import fresh_cache + + torch_inductor_fresh_cache_available = True except ImportError: + torch_inductor_fresh_cache_available = False fresh_cache = nullcontext() @@ -1129,6 +1132,11 @@ def use_fresh_compile_cache(f: Callable[_P, None]) -> Callable[_P, None]: This is useful to ensure that the test is not affected by the previous test calls. """ + if not torch_inductor_fresh_cache_available: + print( + "torch._inductor.utils.fresh_cache is not available, " + "the test will not use fresh inductor cache." + ) @functools.wraps(f) def wrapper(*args, **kwargs): From 925e87d2a715b7d5b328fe4c137ee538df53f55f Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 21 Nov 2025 14:45:38 +0000 Subject: [PATCH 119/137] Fix pre-commit after merge Signed-off-by: ilmarkov --- vllm/v1/worker/gpu_worker.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 30d45cc97a8b..4fd6b8d6356f 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -388,21 +388,20 @@ def compile_or_warm_up_model(self) -> None: # e.g. for the max-num-batched token size in chunked prefill. compile_sizes = self.vllm_config.compilation_config.compile_sizes warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] + capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes if ( not self.model_config.enforce_eager or self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE - ): - capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes - if capture_sizes is not None: - warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes] + ) and capture_sizes is not None: + warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes] compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() # For each compile_range, if none of the batch sizes # in warmup_sizes or cudagraph_capture_sizes are in the range, # add the end of the range to ensure compilation/warmup. - all_sizes = set(self.vllm_config.compilation_config.cudagraph_capture_sizes) - all_sizes.update(warmup_sizes) + all_sizes = set(capture_sizes if capture_sizes is not None else []) + all_sizes.update([x for x in warmup_sizes if isinstance(x, int)]) for compile_range in compile_ranges: if not any(x in compile_range for x in all_sizes): warmup_sizes.append(compile_range.end) From 809e170e13c6ffcb8603c3a33a07f4ab4ae86370 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 21 Nov 2025 16:04:48 +0000 Subject: [PATCH 120/137] Fix tests Signed-off-by: ilmarkov --- vllm/config/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 3f9aa4261b38..7b026c65a0f0 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -321,7 +321,7 @@ def __hash__(self) -> int: return hash((self.start, self.end)) def __str__(self) -> str: - return f"[{self.start}, {self.end}]" + return f"({self.start}, {self.end})" def __repr__(self) -> str: return self.__str__() From e07c939874936d1c47b57b9811031e561b2a6e4f Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 21 Nov 2025 16:24:18 +0000 Subject: [PATCH 121/137] Add fixture instead of decorator Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 7 ++----- tests/conftest.py | 27 ++++++++++++++++++++++++ tests/utils.py | 31 +--------------------------- 3 files changed, 30 insertions(+), 35 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index a1fdff570bcd..d849a8617ebd 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -8,7 +8,6 @@ # This import automatically registers `torch.ops.silly.attention` import tests.compile.silly_attention # noqa -from tests.utils import use_fresh_compile_cache from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.compilation.inductor_pass import ( @@ -67,8 +66,7 @@ def uuid(self) -> str: return InductorPass.hash_dict(state) -@use_fresh_compile_cache -def test_compile_ranges(): +def test_compile_ranges(use_fresh_inductor_cache): post_grad_range_checker = PostGradRangeChecker( [ Range(start=1, end=8), @@ -124,8 +122,7 @@ def test_compile_config_get_compile_ranges(): ] -@use_fresh_compile_cache -def test_inductor_cache_compile_ranges(monkeypatch): +def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache): # To force multiple compilations, we disable the compile cache monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") diff --git a/tests/conftest.py b/tests/conftest.py index b17081352edc..49b78844cadd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,6 +66,14 @@ from vllm.utils.collection_utils import is_list_of from vllm.utils.torch_utils import set_default_torch_num_threads +try: + from torch._inductor.utils import fresh_cache + + torch_inductor_fresh_cache_available = True +except ImportError: + torch_inductor_fresh_cache_available = False + + logger = init_logger(__name__) _TEST_DIR = os.path.dirname(__file__) @@ -1397,3 +1405,22 @@ def disable_deepgemm_ue8m0(monkeypatch): # Clear cache so the next time it is used it is processed with the # default VLLM_USE_DEEP_GEMM_E8M0 setting. is_deep_gemm_e8m0_used.cache_clear() + + +@pytest.fixture +def use_fresh_inductor_cache(): + """ + Use a fresh inductor cache for the test. + This is useful to ensure that the test is not affected by the + previous test calls. + """ + if not torch_inductor_fresh_cache_available: + print( + "torch._inductor.utils.fresh_cache is not available, " + "the test will not use fresh inductor cache." + ) + yield + return + + with fresh_cache(): + yield diff --git a/tests/utils.py b/tests/utils.py index afda4685b44e..c31a2aeeb9c8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,7 +17,7 @@ import time import warnings from collections.abc import Callable, Iterable -from contextlib import ExitStack, contextmanager, nullcontext, suppress +from contextlib import ExitStack, contextmanager, suppress from multiprocessing import Process from pathlib import Path from typing import Any, Literal @@ -50,15 +50,6 @@ from vllm.utils.network_utils import get_open_port from vllm.utils.torch_utils import cuda_device_count_stateless -try: - from torch._inductor.utils import fresh_cache - - torch_inductor_fresh_cache_available = True -except ImportError: - torch_inductor_fresh_cache_available = False - fresh_cache = nullcontext() - - if current_platform.is_rocm(): from amdsmi import ( amdsmi_get_gpu_vram_usage, @@ -1126,26 +1117,6 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: return wrapper -def use_fresh_compile_cache(f: Callable[_P, None]) -> Callable[_P, None]: - """ - Decorator to use a fresh inductor cache for the test. - This is useful to ensure that the test is not affected by the - previous test calls. - """ - if not torch_inductor_fresh_cache_available: - print( - "torch._inductor.utils.fresh_cache is not available, " - "the test will not use fresh inductor cache." - ) - - @functools.wraps(f) - def wrapper(*args, **kwargs): - with fresh_cache(): - return f(*args, **kwargs) - - return wrapper - - async def completions_with_server_args( prompts: list[str], model_name: str, From f4db45c4fe7dae17c1877f7616a0536ba5350585 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Sun, 23 Nov 2025 09:45:41 +0000 Subject: [PATCH 122/137] Fix re-used compilation config Signed-off-by: ilmarkov --- vllm/compilation/inductor_pass.py | 10 ++++++++++ vllm/compilation/pass_manager.py | 20 +++++++++++++++----- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 8159b817f637..044bb0af866b 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -39,6 +39,16 @@ def get_pass_context() -> PassContext: return _pass_context +def try_get_pass_context() -> PassContext | None: + """ + Try to get the current pass context. + Return None if the pass context is not set. + The pass context can be not set if it compilation config is re-used + for multiple model compilations which is only used in the tests. + """ + return _pass_context + + @contextmanager def pass_context(compile_range: Range): """A context manager that stores the current pass context, diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 1165c119cb94..9f9acc581d8c 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -24,7 +24,12 @@ from .collective_fusion import AllReduceFusionPass, AsyncTPPass from .fix_functionalization import FixFunctionalizationPass -from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context +from .inductor_pass import ( + CustomGraphPass, + InductorPass, + get_pass_context, + try_get_pass_context, +) from .noop_elimination import NoOpEliminationPass logger = init_logger(__name__) @@ -131,9 +136,14 @@ def uuid(self): for pass_ in self.passes: state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) - compile_range = get_pass_context().compile_range - # Include the compile range in the uuid to ensure that inductor - # recompiles the graph for the new dynamic compile range. - state["compile_range"] = str(compile_range) + + # If there is no pass context, we are likely re-using compilation config + # In this case uuid() is called not at compilation time, + # we don't have the compile range information. + pass_context = try_get_pass_context() + if pass_context is not None: + # Include the compile range in the uuid to ensure that inductor + # recompiles the graph for the new dynamic compile range. + state["compile_range"] = str(pass_context.compile_range) return InductorPass.hash_dict(state) From 4f280ce27113dca5c615185fd2a91ec676032734 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Sun, 23 Nov 2025 21:22:04 +0000 Subject: [PATCH 123/137] Fix e2e Signed-off-by: ilmarkov --- tests/compile/distributed/test_fusions_e2e.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index 661172e1965b..126636dc756d 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -298,7 +298,9 @@ def test_tp2_attn_quant_allreduce_rmsnorm( r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) - assert len(log_matches) == 2, log_holder.text + # 2 for 2 compile ranges + # (global compile range is split due to enable_fi_allreduce_fusion) + assert len(log_matches) == 4, log_holder.text assert int(log_matches[0]) == matches.attention_fusion assert int(log_matches[1]) == matches.attention_fusion From b27f89d12d47b42cfa41dcd22a932bbc7d439410 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 24 Nov 2025 10:25:22 +0000 Subject: [PATCH 124/137] Fix e2e adapt to number of compile ranges Signed-off-by: ilmarkov --- tests/compile/distributed/test_fusions_e2e.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index 126636dc756d..38af27fd9928 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -298,9 +298,11 @@ def test_tp2_attn_quant_allreduce_rmsnorm( r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) - # 2 for 2 compile ranges - # (global compile range is split due to enable_fi_allreduce_fusion) - assert len(log_matches) == 4, log_holder.text + # 2 for each compile range + # (global compile range can be split due to enable_fi_allreduce_fusion) + assert len(log_matches) == 2 * len(compilation_config.get_compile_ranges()), ( + log_holder.text + ) assert int(log_matches[0]) == matches.attention_fusion assert int(log_matches[1]) == matches.attention_fusion @@ -448,7 +450,6 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg # No cudagraphs by default if compilation_config.cudagraph_mode is None: compilation_config.cudagraph_mode = CUDAGraphMode.NONE - llm = LLM( model=model, compilation_config=compilation_config, @@ -461,3 +462,9 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + # Get the compile ranges split points after vllm config post init + # in order to compute compile ranges correctly + compilation_config.compile_ranges_split_points = ( + llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points + ) From cc8f2f8841dfe0554916c1a182254b3087070716 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 25 Nov 2025 22:12:54 +0000 Subject: [PATCH 125/137] Slight fix of test Signed-off-by: ilmarkov --- tests/compile/distributed/test_fusions_e2e.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index 26412e948a13..e2e76885d8f4 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -315,8 +315,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm( log_holder.text ) - assert int(log_matches[0]) == matches.attention_fusion - assert int(log_matches[1]) == matches.attention_fusion + assert all(int(log_match) == matches.attention_fusion for log_match in log_matches) log_matches = re.findall( r"collective_fusion.py:\d+] Replaced (\d+) patterns", From d1dd4db4627bb90da5fb017111dd1c5a941c9d35 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 25 Nov 2025 22:33:51 +0000 Subject: [PATCH 126/137] Fix tests after refactor Signed-off-by: ilmarkov --- vllm/compilation/backends.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index cb839777aca1..5e13f93d9766 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -538,8 +538,13 @@ def configure_post_pass(self): ) else: # Config should automatically wrap all inductor passes - assert isinstance(self.inductor_config[self.pass_key], InductorPass) - self.pass_manager.add(self.inductor_config[self.pass_key]) + assert isinstance( + self.compilation_config.inductor_compile_config[self.pass_key], + InductorPass, + ) + self.pass_manager.add( + self.compilation_config.inductor_compile_config[self.pass_key] + ) self.inductor_config[self.pass_key] = self.pass_manager def __call__( From a2b67a426850ded75b538c0d62faa32c4b211c49 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 25 Nov 2025 22:37:12 +0000 Subject: [PATCH 127/137] Simplify Signed-off-by: ilmarkov --- vllm/config/compilation.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 28ec5824fb1c..532934cda425 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -1038,10 +1038,7 @@ def get_compile_ranges(self) -> list[Range]: if self.compile_ranges_split_points is None: return [] split_points = sorted(set(self.compile_ranges_split_points)) - compile_ranges = [] - for i, s in enumerate(split_points): - if i == 0: - compile_ranges.append(Range(start=1, end=s)) - else: - compile_ranges.append(Range(start=split_points[i - 1] + 1, end=s)) - return compile_ranges + return [ + Range(start=s + 1, end=e) + for s, e in zip([0] + split_points[:-1], split_points) + ] From 077636469a7c61c1500e2e6ac4e66d1b383face2 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 26 Nov 2025 09:26:12 +0000 Subject: [PATCH 128/137] Address comments Signed-off-by: ilmarkov --- tests/conftest.py | 15 +-------------- vllm/compilation/piecewise_backend.py | 4 +++- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 15fc01155b22..0b4e65a31c79 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,12 +66,7 @@ from vllm.utils.collection_utils import is_list_of from vllm.utils.torch_utils import set_default_torch_num_threads -try: - from torch._inductor.utils import fresh_cache - - torch_inductor_fresh_cache_available = True -except ImportError: - torch_inductor_fresh_cache_available = False +from torch._inductor.utils import fresh_cache logger = init_logger(__name__) @@ -1441,13 +1436,5 @@ def use_fresh_inductor_cache(): This is useful to ensure that the test is not affected by the previous test calls. """ - if not torch_inductor_fresh_cache_available: - print( - "torch._inductor.utils.fresh_cache is not available, " - "the test will not use fresh inductor cache." - ) - yield - return - with fresh_cache(): yield diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 46b680be5f9a..129b9b5deea3 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -119,7 +119,9 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: self.to_be_compiled_ranges.remove(range_entry.compile_range) # args are real arguments - # fakify for range, real args for concrete size + # fakify for range, real args for concrete size. + # For concrete size, we clear the shape env in + # compiler_manager.compile() so no need to fakify. args = ( self._fakify_args(args) if not range_entry.compile_range.is_single_size() From ba90b9e9e3164d6bf835253af9faa9d303ee1432 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 2 Dec 2025 14:37:26 -0500 Subject: [PATCH 129/137] Only warm up model if mode=VLLM_COMPILE MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/v1/worker/gpu_worker.py | 45 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d72856ef9bde..286086609237 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -15,6 +15,7 @@ import vllm.envs as envs from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config.compilation import CompilationMode from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, @@ -403,28 +404,28 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: - # warm up sizes that are not in cudagraph capture sizes, - # but users still want to compile for better performance, - # e.g. for the max-num-batched token size in chunked prefill. - compile_sizes = self.vllm_config.compilation_config.compile_sizes - warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] - capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes - - if ( - not self.model_config.enforce_eager - or self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE - ) and capture_sizes is not None: - warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes] - compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() - - # For each compile_range, if none of the batch sizes - # in warmup_sizes or cudagraph_capture_sizes are in the range, - # add the end of the range to ensure compilation/warmup. - all_sizes = set(capture_sizes if capture_sizes is not None else []) - all_sizes.update([x for x in warmup_sizes if isinstance(x, int)]) - for compile_range in compile_ranges: - if not any(x in compile_range for x in all_sizes): - warmup_sizes.append(compile_range.end) + warmup_sizes = [] + + if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE: + # warm up sizes that are not in cudagraph capture sizes, + # but users still want to compile for better performance, + # e.g. for the max-num-batched token size in chunked prefill. + compile_sizes = self.vllm_config.compilation_config.compile_sizes + warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] + + if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + warmup_sizes = [x for x in warmup_sizes if x not in cg_sizes] + + compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() + # For each compile_range, if none of the batch sizes + # in warmup_sizes or cudagraph_capture_sizes are in the range, + # add the end of the range to ensure compilation/warmup. + all_sizes = set(cg_sizes if cg_sizes is not None else []) + all_sizes.update([x for x in warmup_sizes if isinstance(x, int)]) + for compile_range in compile_ranges: + if not any(x in compile_range for x in all_sizes): + warmup_sizes.append(compile_range.end) # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): From 771203f2d900862b9e8a7807e68b18734419e3f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 2 Dec 2025 14:40:55 -0500 Subject: [PATCH 130/137] Fix capture-sizes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/v1/worker/gpu_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 286086609237..5b2a4fad5e25 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -412,9 +412,9 @@ def compile_or_warm_up_model(self) -> None: # e.g. for the max-num-batched token size in chunked prefill. compile_sizes = self.vllm_config.compilation_config.compile_sizes warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] + cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: - cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes warmup_sizes = [x for x in warmup_sizes if x not in cg_sizes] compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() From 0e0eab9d41c15772e2b8fded9624b62e42f41942 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 2 Dec 2025 14:46:18 -0500 Subject: [PATCH 131/137] Fix doc range MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/compiler_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index f9da8dffda0a..ab56d3561c56 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -70,8 +70,8 @@ def compile( """ Compile the graph with the given example inputs and compiler config, with a range. The `compile_range` specifies the range of the inputs, - it could be concrete size (if compile_sizes is provided), e.g. [4, 4) - or a range [4, 5). + it could be concrete size (if compile_sizes is provided), e.g. [4, 4] + or a range [5, 8]. Right now we only support one variable in ranges for all inputs, which is the batchsize (number of tokens) during inference. From 3d2c36b98cb3feb5002689f6dc306178a3aab16f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 2 Dec 2025 14:48:48 -0500 Subject: [PATCH 132/137] pre-commit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/v1/worker/gpu_worker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 5b2a4fad5e25..aadea3d04aae 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -412,9 +412,11 @@ def compile_or_warm_up_model(self) -> None: # e.g. for the max-num-batched token size in chunked prefill. compile_sizes = self.vllm_config.compilation_config.compile_sizes warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] - cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + cg_sizes = [] if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + cg_sizes = [] if cg_sizes is None else cg_sizes warmup_sizes = [x for x in warmup_sizes if x not in cg_sizes] compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() From 18ff16ed7c0a0ec459666f670f059f82527dc2cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 2 Dec 2025 19:37:51 -0500 Subject: [PATCH 133/137] Fix types for precommit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/v1/worker/gpu_worker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index aadea3d04aae..1d06850d569d 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -412,18 +412,18 @@ def compile_or_warm_up_model(self) -> None: # e.g. for the max-num-batched token size in chunked prefill. compile_sizes = self.vllm_config.compilation_config.compile_sizes warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] - cg_sizes = [] + cg_capture_sizes : list[int] = [] if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes - cg_sizes = [] if cg_sizes is None else cg_sizes - warmup_sizes = [x for x in warmup_sizes if x not in cg_sizes] + cg_capture_sizes = [] if cg_sizes is None else cg_sizes + warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes] compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() # For each compile_range, if none of the batch sizes # in warmup_sizes or cudagraph_capture_sizes are in the range, # add the end of the range to ensure compilation/warmup. - all_sizes = set(cg_sizes if cg_sizes is not None else []) + all_sizes = set(cg_capture_sizes) all_sizes.update([x for x in warmup_sizes if isinstance(x, int)]) for compile_range in compile_ranges: if not any(x in compile_range for x in all_sizes): From 6bc8258cf9f42364e7585af47ed7f99a1a8854b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 2 Dec 2025 19:39:16 -0500 Subject: [PATCH 134/137] Update vllm/v1/worker/gpu_worker.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/v1/worker/gpu_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 1d06850d569d..54c669b8a0ac 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -412,7 +412,7 @@ def compile_or_warm_up_model(self) -> None: # e.g. for the max-num-batched token size in chunked prefill. compile_sizes = self.vllm_config.compilation_config.compile_sizes warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] - cg_capture_sizes : list[int] = [] + cg_capture_sizes: list[int] = [] if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes From f4c0ae79a2af94bddd79ebd1fbda56e580a80c52 Mon Sep 17 00:00:00 2001 From: ProExpertProg Date: Wed, 3 Dec 2025 21:53:20 +0000 Subject: [PATCH 135/137] Check that the pass was skipped in other range Signed-off-by: ProExpertProg --- tests/compile/distributed/test_fusions_e2e.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index 141e6f3461b4..a205ec33d055 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -315,6 +315,12 @@ def test_tp2_attn_quant_allreduce_rmsnorm( assert int(log_matches[0]) == matches.allreduce_fusion assert int(log_matches[1]) == matches.allreduce_fusion + log_matches = re.findall( + r"pass_manager.py:\d+] Skipping .*AllReduceFusionPass.* with compile range", + log_holder.text, + ) + assert len(log_matches) == 2, log_holder.text + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( From 6d42595eba7b65382eaa469df7ae380c401b83de Mon Sep 17 00:00:00 2001 From: ProExpertProg Date: Thu, 4 Dec 2025 19:37:44 +0000 Subject: [PATCH 136/137] Fix e2e test failure Signed-off-by: ProExpertProg --- tests/compile/distributed/test_fusions_e2e.py | 11 ++++++----- vllm/config/vllm.py | 5 +++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index a205ec33d055..8825871d7ddc 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -102,7 +102,7 @@ class ModelBackendTestCase(NamedTuple): ), ModelBackendTestCase( model_name="Qwen/Qwen3-30B-A3B", - model_kwargs=dict(max_model_len=1024), + model_kwargs=dict(max_model_len=1024, load_format="dummy"), backend=AttentionBackendEnum.TRITON_ATTN, matches=Matches( attention_fusion=0, @@ -300,9 +300,10 @@ def test_tp2_attn_quant_allreduce_rmsnorm( ) # 2 for each compile range # (global compile range can be split due to fuse_allreduce_rmsnorm) - assert len(log_matches) == 2 * len(compilation_config.get_compile_ranges()), ( - log_holder.text - ) + num_compile_ranges = len(compilation_config.get_compile_ranges()) + assert num_compile_ranges in [1, 2] + + assert len(log_matches) == 2 * num_compile_ranges, log_holder.text assert all(int(log_match) == matches.attention_fusion for log_match in log_matches) @@ -319,7 +320,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm( r"pass_manager.py:\d+] Skipping .*AllReduceFusionPass.* with compile range", log_holder.text, ) - assert len(log_matches) == 2, log_holder.text + assert len(log_matches) == 2 * (num_compile_ranges - 1), log_holder.text @multi_gpu_test(num_gpus=2) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index e668d72588c0..af6c1422f4bd 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1161,6 +1161,11 @@ def _set_compile_ranges(self): and max_token_num < max_num_batched_tokens ): computed_compile_ranges_split_points.append(max_token_num) + else: + logger.debug( + "Max num batched tokens below allreduce-rms fusion threshold, " + "allreduce-rms fusion will be enabled for all num_tokens." + ) if compilation_config.compile_ranges_split_points is not None: for x in compilation_config.compile_ranges_split_points: From 47bc80142d243e8c2eae8ad3174eeca86eff45fb Mon Sep 17 00:00:00 2001 From: ProExpertProg Date: Thu, 4 Dec 2025 19:44:00 +0000 Subject: [PATCH 137/137] Fix e2e test failure Signed-off-by: ProExpertProg --- tests/compile/distributed/test_fusions_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index 8825871d7ddc..75a81efedea3 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -102,7 +102,7 @@ class ModelBackendTestCase(NamedTuple): ), ModelBackendTestCase( model_name="Qwen/Qwen3-30B-A3B", - model_kwargs=dict(max_model_len=1024, load_format="dummy"), + model_kwargs=dict(max_model_len=1024), backend=AttentionBackendEnum.TRITON_ATTN, matches=Matches( attention_fusion=0,