Commit 0c824fc

morrison-turnansky, ProExpertProg, and ZJY0516 authored
[Frontend] CompilationConfig overhaul (#20283): deprecate use_inductor in favor of backend, simplify custom_ops (#26113)
Signed-off-by: morrison-turnansky <[email protected]>
Signed-off-by: Morrison Turnansky <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Co-authored-by: Jiangyun Zhu <[email protected]>
1 parent eb577e4 commit 0c824fc

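For orientation, here is a minimal before/after sketch of the user-facing change this commit makes (assuming the `vllm.config` import path used by the tests below; a sketch, not documentation):

from vllm.config import CompilationConfig, CompilationLevel

# Deprecated spelling: still accepted, but __post_init__ now maps it onto
# `backend` and emits a deprecation warning.
legacy = CompilationConfig(level=CompilationLevel.PIECEWISE, use_inductor=True)

# Preferred spelling: name the compile backend directly.
config = CompilationConfig(level=CompilationLevel.PIECEWISE, backend="inductor")

# Piecewise compilation without Inductor; custom ops stay on by default.
eager_config = CompilationConfig(level=CompilationLevel.PIECEWISE, backend="eager")
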
File tree

7 files changed: +126 -63 lines changed


tests/compile/piecewise/test_toy_llama.py

Lines changed: 11 additions & 20 deletions
@@ -258,13 +258,13 @@ def tractable_computation(
 
 @torch.inference_mode
 def run_model(
-    llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
+    llama_config, use_compile: bool, backend: str, split_attn: bool = False
 ) -> torch.Tensor:
     if use_compile:
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            use_inductor=use_inductor,
+            backend=backend,
             cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
@@ -338,8 +338,8 @@ def run_model
     return output.cpu()
 
 
-@pytest.mark.parametrize("use_inductor", [True, False])
-def test_toy_llama(use_inductor: bool):
+@pytest.mark.parametrize("backend", ["inductor", "eager"])
+def test_toy_llama(backend: str):
     # compare output with and without piecewise compilation
 
     llama_config = LlamaConfig(
@@ -358,10 +358,10 @@ def test_toy_llama(use_inductor: bool):
         num_backend_compilations=0,
         num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
-        run_model(tractable_config, use_inductor=False, use_compile=False)
+        outputs.append(run_model(llama_config, backend="eager", use_compile=False))
+        run_model(tractable_config, backend="eager", use_compile=False)
 
-    if use_inductor:
+    if backend == "inductor":
         kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
     else:
         kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
@@ -377,10 +377,8 @@ def test_toy_llama(use_inductor: bool):
         num_cudagraph_captured=2,
         **kwargs,
     ):
-        outputs.append(
-            run_model(llama_config, use_inductor=use_inductor, use_compile=True)
-        )
-        run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
+        outputs.append(run_model(llama_config, backend=backend, use_compile=True))
+        run_model(tractable_config, backend=backend, use_compile=True)
 
     with compilation_counter.expect(
         num_graphs_seen=1,  # one graph for the model
@@ -395,16 +393,9 @@ def test_toy_llama(use_inductor: bool):
         ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
         outputs.append(
-            run_model(
-                llama_config,
-                use_inductor=use_inductor,
-                use_compile=True,
-                split_attn=True,
-            )
+            run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
         )
-        run_model(
-            tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
-        )
+        run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
 
     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])

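In effect, each parametrization of the rewritten test builds a config like the following (a paraphrase of run_model above with a hypothetical helper name, not new test code):

def piecewise_config(backend: str) -> CompilationConfig:
    # Mirrors the construction inside run_model; backend is "inductor" or "eager".
    return CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        backend=backend,
        cudagraph_capture_sizes=[1, 2],
    )
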
tests/model_executor/test_enabled_custom_ops.py

Lines changed: 21 additions & 19 deletions
@@ -37,57 +37,59 @@ class Relu3(ReLUSquaredActivation):
 
 
 @pytest.mark.parametrize(
-    "env, torch_level, use_inductor, ops_enabled, default_on",
+    "env, torch_level, backend, ops_enabled, default_on",
     [
         # Default values based on compile level
         # - All by default (no Inductor compilation)
-        (None, 0, False, [True] * 4, True),
-        (None, 1, True, [True] * 4, True),
-        (None, 2, False, [True] * 4, True),
+        (None, 0, "eager", [True] * 4, True),
+        (None, 1, "eager", [True] * 4, True),
+        (None, 2, "eager", [True] * 4, True),
+        (None, 3, "eager", [True] * 4, True),
         # - None by default (with Inductor)
-        (None, 3, True, [False] * 4, False),
-        (None, 4, True, [False] * 4, False),
-        # - All by default (without Inductor)
-        (None, 3, False, [True] * 4, True),
-        (None, 4, False, [True] * 4, True),
+        (None, 0, "inductor", [True] * 4, True),
+        # - None by default (with Inductor)
+        (None, 1, "inductor", [False] * 4, False),
+        (None, 2, "inductor", [False] * 4, False),
+        (None, 3, "inductor", [False] * 4, False),
         # Explicitly enabling/disabling
         #
         # Default: all
         #
         # All but SiluAndMul
-        ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
+        ("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True),
         # Only ReLU3
-        ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
+        ("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False),
         # All but SiluAndMul
-        ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
+        ("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True),
         # All but ReLU3 (even if ReLU2 is on)
-        ("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
+        ("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
         # RMSNorm and SiluAndMul
-        ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
+        ("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False),
         # All but RMSNorm
-        ("-rms_norm", 3, False, [0, 1, 1, 1], True),
+        ("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
         #
         # Default: none
        #
         # Only ReLU3
-        ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
+        ("none,+relu3", 3, "inductor", [0, 0, 0, 1], False),
         # All but RMSNorm
-        ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
+        ("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True),
     ],
 )
 def test_enabled_ops(
     env: Optional[str],
     torch_level: int,
-    use_inductor: bool,
+    backend: str,
     ops_enabled: list[int],
     default_on: bool,
 ):
     custom_ops = env.split(",") if env else []
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
+            backend=backend, level=torch_level, custom_ops=custom_ops
         )
     )
+    # breakpoint()
     with set_current_vllm_config(vllm_config):
         assert CustomOp.default_on() == default_on

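As a rough standalone paraphrase of the enablement rule the table above encodes (not vLLM code; `spec` stands for custom_ops after VllmConfig has appended the "all"/"none" default):

def op_enabled(spec: list[str], op_name: str) -> bool:
    # Explicit "+op" / "-op" entries win; otherwise fall back to the default.
    if f"+{op_name}" in spec:
        return True
    if f"-{op_name}" in spec:
        return False
    return "all" in spec  # a "none" entry means default-off

For example, the "+rms_norm,-silu_and_mul" row at level 0 resolves against spec ["+rms_norm", "-silu_and_mul", "all"]: rms_norm and the unlisted ops are on, silu_and_mul is off, matching [1, 0, 1, 1].
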
vllm/compilation/backends.py

Lines changed: 5 additions & 1 deletion
@@ -34,7 +34,7 @@
 
 
 def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
-    if compilation_config.use_inductor:
+    if compilation_config.backend == "inductor":
         # Use standalone compile only if requested, version is new enough,
         # and the symbol actually exists in this PyTorch build.
         if (
@@ -48,6 +48,10 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
             logger.debug("Using InductorAdaptor")
             return InductorAdaptor()
     else:
+        assert compilation_config.backend == "eager", (
+            "Custom backends not supported with CompilationLevel.PIECEWISE"
+        )
+
         logger.debug("Using EagerAdaptor")
         return EagerAdaptor()

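A minimal sketch of the resulting dispatch (which Inductor adaptor gets picked still depends on the standalone-compile check shown in the hunk):

from vllm.compilation.backends import make_compiler
from vllm.config import CompilationConfig

make_compiler(CompilationConfig(backend="inductor"))  # Inductor path
make_compiler(CompilationConfig(backend="eager"))     # EagerAdaptor
# Any other backend string is rejected for piecewise compilation, either in
# CompilationConfig.__post_init__ or by the new assert above as a backstop.
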
vllm/config/compilation.py

Lines changed: 60 additions & 11 deletions
@@ -180,10 +180,11 @@ class CompilationConfig:
     """The directory to store the compiled graph, to accelerate Inductor
     compilation. By default, it will use model-related information to generate
     a cache directory."""
-    backend: str = ""
+    backend: str = "inductor"
     """The backend for compilation. It needs to be a string:
 
-    - "" (empty string): use the default backend.
+    - "" (empty string): use the default backend ("inductor" on CUDA-alike
+      platforms).
     - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
     - "full.module.name": a qualified name which can be used to import the
@@ -192,7 +193,11 @@ class CompilationConfig:
     distributed setting. When the compilation level is 1 or 2, the backend is
     used for the compilation directly (it sees the whole graph). When the
     compilation level is 3, the backend is used for the piecewise compilation
-    (it sees a part of the graph)."""
+    (it sees a part of the graph). The backend can not be custom for compilation
+    level 3. Furthermore, compilation is only piecewise if splitting ops is set
+    accordingly and use_inductor_cudagraphs_partition is off. Note that the
+    default options for splitting ops are sufficient for piecewise compilation.
+    """
     custom_ops: list[str] = field(default_factory=list)
     """Fine-grained control over which custom ops to enable/disable. Use 'all'
     to enable all, 'none' to disable all. Also specify a list of custom op
@@ -210,16 +215,24 @@ class CompilationConfig:
     compilation."""
 
     # Inductor capture
-    use_inductor: bool = True
-    """Whether to use inductor compilation:
+    use_inductor: Optional[bool] = None
+    """
+    Whether to use inductor compilation.
+
+    This flag is deprecated and will be removed.
+    Please use the 'backend' option instead.
 
     - False: inductor compilation is not used. graph runs in eager
       (custom_ops enabled by default).
     - True: inductor compilation is used (custom_ops disabled by default).
       One graph for symbolic shape and one graph per size in compile_sizes
       are compiled using configurations in inductor_compile_config.
 
-    This setting is ignored if level<PIECEWISE."""
+    This setting is ignored if level<PIECEWISE.
+
+    For future compatibility:
+    If use_inductor is True, backend="inductor" otherwise backend="eager".
+    """
     compile_sizes: Optional[list[Union[int, str]]] = None
     """Sizes to compile for inductor. In addition
     to integers, it also supports "cudagraph_capture_sizes" to
@@ -523,23 +536,59 @@ def __post_init__(self) -> None:
                 "(where 'op' is the registered op name)"
             )
 
+        # Currently only eager and inductor backend are supported.
+        # for piecewise compilation. Custom backends are not suppported for
+        # piecewise compilation. Update when more backends are supported.
+        if self.level == CompilationLevel.PIECEWISE and self.backend not in [
+            "",
+            "eager",
+            "inductor",
+        ]:
+            raise ValueError(
+                f"Invalid backend for piecewise compilation: {self.backend}"
+            )
+
+        if self.use_inductor is not None:
+            logger.warning_once(
+                "The 'use_inductor' flag is deprecated and will be\
+                removed in a future release."
+                "Please use the 'backend' option instead.",
+            )
+            self.backend = "inductor" if self.use_inductor else "eager"
+
+        if self.backend == "":
+            self.backend = "inductor"
+
     def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
+        """
+        Initialize the backend for the compilation config from a vllm config.
+        Arguments:
+            vllm_config: The vllm config to initialize the backend from.
+        Returns:
+            The backend for the compilation config.
+        """
+        if self.level is None:
+            raise ValueError(
+                "No compilation level is set. This method should only be \
+                called via vllm config where the level is set if none is \
+                provided."
+            )
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
 
         from torch._dynamo.backends.registry import list_backends
 
         torch_backends = list_backends(exclude_tags=tuple())
         if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
-            if self.backend == "":
-                return "eager"
             if self.backend in torch_backends:
                 return self.backend
             return resolve_obj_by_qualname(self.backend)
 
-        # TODO: pass user-specified backend to piecewise compilation
-        # merge with the config use_inductor
         assert self.level == CompilationLevel.PIECEWISE
+        if self.backend not in ["eager", "inductor"]:
+            raise ValueError(
+                f"Invalid backend for piecewise compilation: {self.backend}"
+            )
 
         from vllm.compilation.backends import VllmBackend
 
@@ -692,7 +741,7 @@ def is_attention_compiled_piecewise(self) -> bool:
         )
 
         inductor_used = (
-            self.level == CompilationLevel.PIECEWISE and self.use_inductor
+            self.level == CompilationLevel.PIECEWISE and self.backend == "inductor"
         ) or (
             self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
         )

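A short sketch of the precedence __post_init__ now applies, based only on the hunk above (the deprecation warning is elided; "my.module.backend" is a hypothetical qualified name):

cfg = CompilationConfig(level=CompilationLevel.PIECEWISE, use_inductor=False)
assert cfg.backend == "eager"      # the deprecated flag wins and triggers a warning

cfg = CompilationConfig(level=CompilationLevel.PIECEWISE, backend="")
assert cfg.backend == "inductor"   # empty string falls back to the default

# Custom backends remain usable at levels 1-2, but at PIECEWISE this now raises
# ValueError in __post_init__:
# CompilationConfig(level=CompilationLevel.PIECEWISE, backend="my.module.backend")
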
vllm/config/vllm.py

Lines changed: 19 additions & 0 deletions
@@ -318,6 +318,25 @@ def __post_init__(self):
             # NB: Passing both --enforce-eager and a compilation level
             # in V0 means the compilation level wins out.
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
+        else:
+            assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
+            assert self.compilation_config.level <= CompilationLevel.PIECEWISE
+            assert self.compilation_config.level <= 3
+
+        # If user does not set custom ops via none or all set it here based on
+        # compilation level and backend.
+        if (
+            self.compilation_config.custom_ops.count("none")
+            + self.compilation_config.custom_ops.count("all")
+            == 0
+        ):
+            if (
+                self.compilation_config.level > 0
+                and self.compilation_config.backend != "eager"
+            ):
+                self.compilation_config.custom_ops.append("none")
+            else:
+                self.compilation_config.custom_ops.append("all")
 
         # async tp is built on top of sequence parallelism
         # and requires it to be enabled.

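The new default-filling block behaves like this standalone helper (hypothetical name, same logic as the hunk above):

def fill_custom_ops_default(level: int, backend: str, custom_ops: list[str]) -> list[str]:
    # Only applies when the user specified neither "all" nor "none".
    if custom_ops.count("none") + custom_ops.count("all") == 0:
        if level > 0 and backend != "eager":
            custom_ops.append("none")  # compiling with a real backend: custom ops off
        else:
            custom_ops.append("all")   # no compilation or eager backend: custom ops on
    return custom_ops

assert fill_custom_ops_default(3, "inductor", []) == ["none"]
assert fill_custom_ops_default(3, "eager", ["+rms_norm"]) == ["+rms_norm", "all"]
assert fill_custom_ops_default(0, "inductor", []) == ["all"]
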
vllm/model_executor/custom_op.py

Lines changed: 10 additions & 10 deletions
@@ -114,7 +114,9 @@ def enabled(cls) -> bool:
         custom_ops = compilation_config.custom_ops
         if not hasattr(cls, "name"):
             logger.warning_once(
-                "Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.",  # noqa: E501
+                "Custom op %s was not registered, which means it won't appear\
+                in the op registry. It will be enabled/disabled based on the\
+                global settings.",  # noqa: E501
                 cls.__name__,
             )
             return CustomOp.default_on()
@@ -128,19 +130,17 @@ def enabled(cls) -> bool:
     @staticmethod
     def default_on() -> bool:
         """
-        On by default if PyTorch Inductor is not used.
-        Specifying 'all' or 'none' in custom_op takes precedence.
+        Behavior controlled by `CompilationConfig.custom_ops`: On by default if
+        'all', off by default if 'none'.
+        When PyTorch Inductor is used, 'none' is the default value,
+        otherwise 'all'.
         """
-        from vllm.config import CompilationLevel
-
         compilation_config = get_cached_compilation_config()
-        default_on = (
-            compilation_config.level < CompilationLevel.PIECEWISE
-            or not compilation_config.use_inductor
-        )
         count_none = compilation_config.custom_ops.count("none")
         count_all = compilation_config.custom_ops.count("all")
-        return default_on and not count_none > 0 or count_all > 0
+        assert count_none + count_all == 1
+
+        return not count_none > 0 or count_all > 0
 
     # Dictionary of all custom ops (classes, indexed by registered name).
     # To check if an op with a name is enabled, call .enabled() on the class.

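With this simplification, default_on() relies on VllmConfig.__post_init__ having already appended exactly one of "all"/"none" to custom_ops; a small illustration of the resulting behavior (a paraphrase, not library code):

# custom_ops after VllmConfig.__post_init__   default_on()   effect
# ["all"]                                     True           custom implementations used by default
# ["none"]                                    False          ops fall back to native torch code
# ["none", "+rms_norm"]                       False          only rms_norm opts back in via enabled()
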
vllm/platforms/cpu.py

Lines changed: 0 additions & 2 deletions
@@ -274,8 +274,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "epilogue_fusion": True,
             }
         )
-        if compilation_config.use_inductor:
-            compilation_config.custom_ops = ["none"]
 
         if vllm_config.lora_config is not None:
             compilation_config.level = CompilationLevel.NO_COMPILATION
