From 948efa4f711ae2487d2e1d3735699a4b57414dfc Mon Sep 17 00:00:00 2001 From: "Lucia (Lu) Fang" Date: Thu, 16 Oct 2025 10:02:58 -0700 Subject: [PATCH 1/5] Passing only necessary compilation config to inductor pass config (#27041) Summary: Pull Request resolved: https://github.com/vllm-project/vllm/pull/27041 We should not pass the weakref to compilation_config, which includes static_forward_context that holds pointers to the model layers (e.g. MoE, attention); this is dangerous because it will be passed as config to torch.compile. Test Plan: local tests Differential Revision: D84790018 Signed-off-by: Lu Fang --- vllm/compilation/vllm_inductor_pass.py | 28 +++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index beac928b5d71..2ad802c26d3d 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -3,7 +3,7 @@ import functools import operator import time -import weakref +from dataclasses import dataclass from typing import ClassVar import regex as re @@ -11,7 +11,7 @@ from torch._dynamo.utils import lazy_format_graph_code from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter -from vllm.config import VllmConfig +from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger from .inductor_pass import InductorPass @@ -19,6 +19,26 @@ logger = init_logger(__name__) +@dataclass +class SimplifiedCompilationConfig: + splitting_ops: list[str] | None = None + use_inductor_graph_partition: bool = False + compile_sizes: list[int | str] | None = None + + +def copy_necessary_config_for_pass( + config: CompilationConfig, +) -> SimplifiedCompilationConfig: + """Get only the necessary CompilationConfig for the inductor pass, since + full `CompilationConfig` contains pointer to model which is unsafe. 
+ """ + return SimplifiedCompilationConfig( + splitting_ops=config.splitting_ops, + use_inductor_graph_partition=config.use_inductor_graph_partition, + compile_sizes=config.compile_sizes, + ) + + class VllmInductorPass(InductorPass): """ An inductor pass with access to vLLM PassConfig. @@ -29,7 +49,9 @@ class VllmInductorPass(InductorPass): """Keep track of pass index for debug dump ordering.""" def __init__(self, config: VllmConfig): - self.compilation_config = weakref.proxy(config.compilation_config) + self.compilation_config = copy_necessary_config_for_pass( + config.compilation_config + ) self.pass_config = config.compilation_config.pass_config self.model_dtype = config.model_config.dtype if config.model_config else None self.device = config.device_config.device if config.device_config else None From 1fd5983659401cac85ecd009bd8886b3117d4ed1 Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Thu, 16 Oct 2025 10:16:18 -0700 Subject: [PATCH 2/5] rename Signed-off-by: Lu Fang --- vllm/compilation/vllm_inductor_pass.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 2ad802c26d3d..5fa7c304a404 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -11,7 +11,7 @@ from torch._dynamo.utils import lazy_format_graph_code from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from .inductor_pass import InductorPass @@ -20,25 +20,12 @@ @dataclass -class SimplifiedCompilationConfig: +class InductorCompilationConfig: splitting_ops: list[str] | None = None use_inductor_graph_partition: bool = False compile_sizes: list[int | str] | None = None -def copy_necessary_config_for_pass( - config: CompilationConfig, -) -> SimplifiedCompilationConfig: - """Get only 
the necessary CompilationConfig for the inductor pass, since - full `CompilationConfig` contains pointer to model which is unsafe. - """ - return SimplifiedCompilationConfig( - splitting_ops=config.splitting_ops, - use_inductor_graph_partition=config.use_inductor_graph_partition, - compile_sizes=config.compile_sizes, - ) - - class VllmInductorPass(InductorPass): """ An inductor pass with access to vLLM PassConfig. @@ -49,8 +36,12 @@ class VllmInductorPass(InductorPass): """Keep track of pass index for debug dump ordering.""" def __init__(self, config: VllmConfig): - self.compilation_config = copy_necessary_config_for_pass( - config.compilation_config + # Get only the necessary CompilationConfig for the inductor pass, since + # full `CompilationConfig` contains pointer to model which is unsafe. + self.compilation_config = InductorCompilationConfig( + splitting_ops=config.splitting_ops, + use_inductor_graph_partition=config.use_inductor_graph_partition, + compile_sizes=config.compile_sizes, ) self.pass_config = config.compilation_config.pass_config self.model_dtype = config.model_config.dtype if config.model_config else None From 74a3cc7d6616cd656c9b70893ae22be2b890268e Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Thu, 16 Oct 2025 10:42:02 -0700 Subject: [PATCH 3/5] add tests and fix Signed-off-by: Lu Fang --- tests/compile/test_async_tp.py | 9 +++++++++ tests/compile/test_sequence_parallelism.py | 8 ++++++++ 2 files changed, 17 insertions(+) diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 60856f5a5806..cce99d0c4f4c 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -341,6 +341,15 @@ def async_tp_pass_on_test_model( async_tp_pass = AsyncTPPass(vllm_config) backend = TestBackend(async_tp_pass) + assert ( + async_tp_pass.compilation_config.splitting_ops + == vllm_config.compilation_config.splitting_ops + ) + assert ( + async_tp_pass.compilation_config.use_inductor_graph_partition + == 
vllm_config.compilation_config.use_inductor_graph_partition + ) + model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor hidden_states = torch.randn( diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 6abab88e6369..9969a629c008 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -285,6 +285,14 @@ def sequence_parallelism_pass_on_test_model( noop_pass = NoOpEliminationPass(vllm_config) sequence_parallelism_pass = SequenceParallelismPass(vllm_config) + assert ( + sequence_parallelism_pass.compilation_config.splitting_ops + == vllm_config.compilation_config.splitting_ops + ) + assert ( + sequence_parallelism_pass.compilation_config.use_inductor_graph_partition + == vllm_config.compilation_config.use_inductor_graph_partition + ) func_pass = FixFunctionalizationPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) From f66288a5f2140038cb04d330b8bc3cc142c7a552 Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Thu, 16 Oct 2025 11:06:39 -0700 Subject: [PATCH 4/5] add more tests Signed-off-by: Lu Fang --- tests/compile/test_config.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 7f51c763da73..87b5d167d168 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy + import pytest from vllm.compilation.counter import compilation_counter +from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config.compilation import CompilationMode from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer @@ -25,6 +28,20 @@ def test_use_cudagraphs_dynamic(): assert 
vllm_config.compilation_config.use_cudagraph +def test_copy_pass(): + vllm_config = VllmConfig() + inductor_pass = FixFunctionalizationPass(vllm_config) + copied_inductor_pass = copy.deepcopy(inductor_pass) + assert ( + copied_inductor_pass.compilation_config.use_inductor_graph_partition + == vllm_config.compilation_config.use_inductor_graph_partition + ) + assert ( + copied_inductor_pass.compilation_config.splitting_ops + == vllm_config.compilation_config.splitting_ops + ) + + def test_custom_op(): # proper syntax _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"]) From 35a7b8c9477087b923f4bee14b7d0dfd196b138b Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Thu, 16 Oct 2025 15:00:26 -0700 Subject: [PATCH 5/5] fix tests Signed-off-by: Lu Fang --- vllm/compilation/vllm_inductor_pass.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 5fa7c304a404..7ef2dddcb407 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -23,7 +23,6 @@ class InductorCompilationConfig: splitting_ops: list[str] | None = None use_inductor_graph_partition: bool = False - compile_sizes: list[int | str] | None = None class VllmInductorPass(InductorPass): @@ -39,9 +38,8 @@ def __init__(self, config: VllmConfig): # Get only the necessary CompilationConfig for the inductor pass, since # full `CompilationConfig` contains pointer to model which is unsafe. self.compilation_config = InductorCompilationConfig( - splitting_ops=config.splitting_ops, - use_inductor_graph_partition=config.use_inductor_graph_partition, - compile_sizes=config.compile_sizes, + splitting_ops=config.compilation_config.splitting_ops, + use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition, ) self.pass_config = config.compilation_config.pass_config self.model_dtype = config.model_config.dtype if config.model_config else None