Skip to content

Commit 948efa4

Browse files
Lucia (Lu) Fangluccafong
authored andcommitted
Passing only necessary compilation config to inductor pass config (#27041)
Summary: Pull Request resolved: #27041 we should not pass the weakref to compilation_config, which include static_forward_context that will holds the pointers to the model layers (e.g. moe, attention), which is dangerous, as this will be passed as config to torch.compile Test Plan: local tests Differential Revision: D84790018 Signed-off-by: Lu Fang <[email protected]>
1 parent 7bb736d commit 948efa4

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

vllm/compilation/vllm_inductor_pass.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,42 @@
33
import functools
44
import operator
55
import time
6-
import weakref
6+
from dataclasses import dataclass
77
from typing import ClassVar
88

99
import regex as re
1010
import torch
1111
from torch._dynamo.utils import lazy_format_graph_code
1212
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
1313

14-
from vllm.config import VllmConfig
14+
from vllm.config import CompilationConfig, VllmConfig
1515
from vllm.logger import init_logger
1616

1717
from .inductor_pass import InductorPass
1818

1919
logger = init_logger(__name__)
2020

2121

22+
@dataclass
23+
class SimplifiedCompilationConfig:
24+
splitting_ops: list[str] | None = None
25+
use_inductor_graph_partition: bool = False
26+
compile_sizes: list[int | str] | None = None
27+
28+
29+
def copy_necessary_config_for_pass(
30+
config: CompilationConfig,
31+
) -> SimplifiedCompilationConfig:
32+
"""Get only the necessary CompilationConfig for the inductor pass, since
33+
full `CompilationConfig` contains pointer to model which is unsafe.
34+
"""
35+
return SimplifiedCompilationConfig(
36+
splitting_ops=config.splitting_ops,
37+
use_inductor_graph_partition=config.use_inductor_graph_partition,
38+
compile_sizes=config.compile_sizes,
39+
)
40+
41+
2242
class VllmInductorPass(InductorPass):
2343
"""
2444
An inductor pass with access to vLLM PassConfig.
@@ -29,7 +49,9 @@ class VllmInductorPass(InductorPass):
2949
"""Keep track of pass index for debug dump ordering."""
3050

3151
def __init__(self, config: VllmConfig):
32-
self.compilation_config = weakref.proxy(config.compilation_config)
52+
self.compilation_config = copy_necessary_config_for_pass(
53+
config.compilation_config
54+
)
3355
self.pass_config = config.compilation_config.pass_config
3456
self.model_dtype = config.model_config.dtype if config.model_config else None
3557
self.device = config.device_config.device if config.device_config else None

0 commit comments

Comments
 (0)