-
-
Notifications
You must be signed in to change notification settings - Fork 12k
[torch.compile] Passing only necessary compilation config to inductor pass config #27041
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,22 +3,42 @@ | |
| import functools | ||
| import operator | ||
| import time | ||
| import weakref | ||
| from dataclasses import dataclass | ||
| from typing import ClassVar | ||
|
|
||
| import regex as re | ||
| import torch | ||
| from torch._dynamo.utils import lazy_format_graph_code | ||
| from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter | ||
|
|
||
| from vllm.config import VllmConfig | ||
| from vllm.config import CompilationConfig, VllmConfig | ||
| from vllm.logger import init_logger | ||
|
|
||
| from .inductor_pass import InductorPass | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
| @dataclass | ||
| class SimplifiedCompilationConfig: | ||
| splitting_ops: list[str] | None = None | ||
| use_inductor_graph_partition: bool = False | ||
| compile_sizes: list[int | str] | None = None | ||
|
|
||
|
|
||
def copy_necessary_config_for_pass(
    config: CompilationConfig,
) -> SimplifiedCompilationConfig:
    """Return only the `CompilationConfig` fields the inductor pass needs.

    The full `CompilationConfig` contains a pointer to the model, which is
    unsafe for the pass to hold, so just the required fields are carried over.

    Args:
        config: The full compilation config to read from.

    Returns:
        A `SimplifiedCompilationConfig` holding `splitting_ops`,
        `use_inductor_graph_partition` and `compile_sizes` from `config`.
    """
    # NOTE: list-valued fields are passed through by reference, not copied.
    return SimplifiedCompilationConfig(
        splitting_ops=config.splitting_ops,
        use_inductor_graph_partition=config.use_inductor_graph_partition,
        compile_sizes=config.compile_sizes,
    )
|
|
||
|
|
||
| class VllmInductorPass(InductorPass): | ||
| """ | ||
| An inductor pass with access to vLLM PassConfig. | ||
|
|
@@ -29,7 +49,9 @@ class VllmInductorPass(InductorPass): | |
| """Keep track of pass index for debug dump ordering.""" | ||
|
|
||
| def __init__(self, config: VllmConfig): | ||
| self.compilation_config = weakref.proxy(config.compilation_config) | ||
| self.compilation_config = copy_necessary_config_for_pass( | ||
| config.compilation_config | ||
| ) | ||
| self.pass_config = config.compilation_config.pass_config | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Later, pass_config can also be moved.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we can utilize it, but it would introduce a duplicated attribute at the config level; we can think about how to organize these configs better in a follow-up PR. @zou3519 |
||
| self.model_dtype = config.model_config.dtype if config.model_config else None | ||
| self.device = config.device_config.device if config.device_config else None | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.