Skip to content

Commit d161999

Browse files
Lucia (Lu) Fangfacebook-github-bot
authored andcommitted
Passing only necessary compilation config to inductor pass config
Summary: We should not pass the weakref to compilation_config, which includes static_forward_context and therefore holds pointers to the model layers (e.g. MoE, attention). This is dangerous because it is passed as config to torch.compile. Differential Revision: D84790018
1 parent 9f4e309 commit d161999

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

vllm/compilation/vllm_inductor_pass.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,34 @@
33
import functools
44
import operator
55
import time
6-
import weakref
76
from typing import ClassVar
87

98
import regex as re
109
import torch
10+
from dataclasses import dataclass
1111
from torch._dynamo.utils import lazy_format_graph_code
1212
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
1313

14-
from vllm.config import VllmConfig
14+
from vllm.config import CompilationConfig, VllmConfig
1515
from vllm.logger import init_logger
1616

1717
from .inductor_pass import InductorPass
1818

1919
logger = init_logger(__name__)
2020

21+
@dataclass
22+
class SimplifiedCompilationConfig:
23+
splitting_ops: list[str] | None = None
24+
use_inductor_graph_partition: bool = False
25+
compile_sizes: bool = False
26+
27+
def copy_necessary_config_for_pass(
    config: CompilationConfig,
) -> SimplifiedCompilationConfig:
    """Get the necessary CompilationConfig for the current pass.

    Reads ``splitting_ops``, ``use_inductor_graph_partition`` and
    ``compile_sizes`` from ``config`` and stores them in a new
    ``SimplifiedCompilationConfig``. No other attribute of ``config`` is
    touched, and ``config`` itself is not modified.

    Args:
        config: Object to read the three attributes from.

    Returns:
        A new ``SimplifiedCompilationConfig`` populated with those values.
    """

    def _detach(value):
        # Shallow-copy list values so the returned object does not alias a
        # mutable list still owned by ``config``; any other value (None,
        # bool, ...) is passed through unchanged.
        return list(value) if isinstance(value, list) else value

    return SimplifiedCompilationConfig(
        splitting_ops=_detach(config.splitting_ops),
        use_inductor_graph_partition=config.use_inductor_graph_partition,
        compile_sizes=_detach(config.compile_sizes),
    )
2134

2235
class VllmInductorPass(InductorPass):
2336
"""
@@ -29,7 +42,7 @@ class VllmInductorPass(InductorPass):
2942
"""Keep track of pass index for debug dump ordering."""
3043

3144
def __init__(self, config: VllmConfig):
32-
self.compilation_config = weakref.proxy(config.compilation_config)
45+
self.compilation_config = copy_necessary_config_for_pass(config.compilation_config)
3346
self.pass_config = config.compilation_config.pass_config
3447
self.model_dtype = config.model_config.dtype if config.model_config else None
3548
self.device = config.device_config.device if config.device_config else None

0 commit comments

Comments
 (0)