Commit 3fe31bf

[compile] Turn on TP/SP when use_inductor_graph_partition=True

Signed-off-by: angelayi <[email protected]>

Parent: 95e66d7
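
The commit lets the sequence-parallelism and collective-fusion passes apply when use_inductor_graph_partition is set, not only when no Dynamo splitting ops are used. As a rough usage sketch, the combination might be expressed in a compilation config along the following lines; apart from use_inductor_graph_partition and splitting_ops, which appear in the diffs below, the field names are assumptions and may not match the actual vLLM API:

# Hypothetical sketch only: field names other than use_inductor_graph_partition
# are assumptions, not the confirmed vLLM configuration surface.
compilation_config = {
    "use_inductor_graph_partition": True,     # partition inside Inductor, no Dynamo splitting ops
    "pass_config": {
        "enable_sequence_parallelism": True,  # assumed flag name for the SP/TP fusion passes
    },
}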

File tree

4 files changed: +17, -8 lines

  vllm/compilation/collective_fusion.py
  vllm/compilation/pass_manager.py
  vllm/compilation/sequence_parallelism.py
  vllm/compilation/vllm_inductor_pass.py

vllm/compilation/collective_fusion.py

Lines changed: 5 additions & 1 deletion

@@ -435,7 +435,11 @@ def is_applicable(self, shape: int | None) -> bool:
         # This pass is applied on top of the sequence parallelism pass.
         # It inherits the same applicability condition as `SequenceParallelismPass`.
         # See `SequenceParallelismPass.is_applicable` for more details.
-        if self.splitting_ops is None or self.splitting_ops == []:
+        splitting_ops = self.compilation_config.splitting_ops  # type: ignore[attr-defined]
+        use_inductor_graph_partition = (
+            self.compilation_config.use_inductor_graph_partition  # type: ignore[attr-defined]
+        )
+        if not splitting_ops or use_inductor_graph_partition:
             return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
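
The net effect of the hunk above is that an empty or missing splitting_ops list and use_inductor_graph_partition=True are now treated the same way: in both cases the pass sees the whole graph and always applies. A minimal standalone sketch of the resulting predicate (plain Python, not vLLM code; tp_size stands in for get_tensor_model_parallel_world_size()):

def is_applicable(
    shape: int | None,
    splitting_ops: list[str] | None,
    use_inductor_graph_partition: bool,
    tp_size: int,
) -> bool:
    # Full graph available: no Dynamo splitting ops, or the partitioning is
    # done inside Inductor instead. The pass always applies.
    if not splitting_ops or use_inductor_graph_partition:
        return True
    # Otherwise the concrete shape must divide evenly across TP ranks.
    return shape is not None and shape % tp_size == 0

# `not splitting_ops` covers both None and [] (the old two-branch check).
assert is_applicable(None, [], False, tp_size=2)
assert is_applicable(None, ["some.op"], True, tp_size=2)
assert is_applicable(8, ["some.op"], False, tp_size=2)
assert not is_applicable(7, ["some.op"], False, tp_size=2)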

vllm/compilation/pass_manager.py

Lines changed: 2 additions & 0 deletions

@@ -74,6 +74,8 @@ def __call__(self, graph: fx.Graph):
             if pass_.is_applicable(shape):
                 pass_(graph)
                 VllmInductorPass.dump_prefix += 1
+            else:
+                logger.debug("Skipping %s with shape %s", pass_, shape)
 
         # post-cleanup goes before fix_functionalization
         # because it requires a functional graph
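
The new else branch passes pass_ and shape as logger arguments instead of formatting them inline, so the message is only built when debug logging is enabled. A self-contained sketch of that pattern with the standard-library logger (generic code, not the vLLM pass manager):

import logging

logger = logging.getLogger("pass_manager_demo")
logging.basicConfig(level=logging.DEBUG)

passes = ["collective_fusion", "sequence_parallelism"]  # placeholder names
shape = 7

for pass_ in passes:
    applicable = shape % 2 == 0  # stand-in for pass_.is_applicable(shape)
    if applicable:
        pass  # the real manager would run pass_(graph) here
    else:
        # Lazy %-style interpolation: args are only formatted if DEBUG is enabled.
        logger.debug("Skipping %s with shape %s", pass_, shape)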

vllm/compilation/sequence_parallelism.py

Lines changed: 6 additions & 2 deletions

@@ -490,13 +490,17 @@ def is_applicable(self, shape: int | None) -> bool:
         #
         # This pass is therefore only applied when the sequence dimension is
         # concrete:
-        # 1. In full-graph compilation mode (no splitting ops are used).
+        # 1. In full-graph compilation mode (no Dynamo splitting ops are used).
         #    For this case we always pad num_tokens to be a multiple of
         #    tensor_parallel_size, so there's no need to check shape % tp_size == 0.
         # 2. For specific shape provided during compilation (e.g., from
         #    `compile_sizes`), which must be divisible by the tensor-parallel
         #    size.
-        if self.splitting_ops is None or self.splitting_ops == []:
+        splitting_ops = self.compilation_config.splitting_ops  # type: ignore[attr-defined]
+        use_inductor_graph_partition = (
+            self.compilation_config.use_inductor_graph_partition  # type: ignore[attr-defined]
+        )
+        if not splitting_ops or use_inductor_graph_partition:
             return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
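
Case 1 in the comment above relies on num_tokens being padded to a multiple of tensor_parallel_size, which is what makes the shape % tp_size check unnecessary in full-graph mode. A minimal illustration of that padding argument (plain Python; pad_to_multiple is a hypothetical helper, not a vLLM function):

def pad_to_multiple(num_tokens: int, tp_size: int) -> int:
    # Round num_tokens up to the next multiple of tp_size.
    return -(-num_tokens // tp_size) * tp_size

# After padding, the token dimension always divides evenly across TP ranks,
# so case 1 never needs the explicit shape % tp_size == 0 check.
for num_tokens in (1, 7, 8, 13):
    assert pad_to_multiple(num_tokens, tp_size=4) % 4 == 0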

vllm/compilation/vllm_inductor_pass.py

Lines changed: 4 additions & 5 deletions

@@ -3,6 +3,7 @@
 import functools
 import operator
 import time
+import weakref
 from typing import ClassVar
 
 import regex as re

@@ -28,12 +29,10 @@ class VllmInductorPass(InductorPass):
     """Keep track of pass index for debug dump ordering."""
 
     def __init__(self, config: VllmConfig):
+        self.compilation_config = weakref.ref(config.compilation_config)
         self.pass_config = config.compilation_config.pass_config
-        self.splitting_ops = config.compilation_config.splitting_ops
-        self.model_dtype = config.model_config.dtype if config.model_config \
-            else None
-        self.device = config.device_config.device if config.device_config \
-            else None
+        self.model_dtype = config.model_config.dtype if config.model_config else None
+        self.device = config.device_config.device if config.device_config else None
         self.pass_name = self.__class__.__name__
 
     @staticmethod
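
The constructor now stores the compilation config behind weakref.ref, so the pass object does not by itself keep the config alive. A minimal illustration of the stdlib behavior involved (generic Python, not vLLM classes): the referent is obtained by calling the reference, and the call returns None once the referent has been garbage-collected.

import gc
import weakref


class DummyConfig:  # stand-in for the real compilation config object
    splitting_ops = ["example_op"]  # placeholder value


config = DummyConfig()
ref = weakref.ref(config)

# Dereference by calling the weak reference.
assert ref() is config
assert ref().splitting_ops == ["example_op"]

# Once the last strong reference is dropped, the weak reference returns None.
del config
gc.collect()
assert ref() is None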
