
Commit 09d2444

angelayi authored and Zhathw committed

[compile] Enable sequence parallelism for full cuda graph without specifying compile sizes (vllm-project#26681)

Signed-off-by: angelayi <[email protected]>

1 parent b026a1c · commit 09d2444

File tree

5 files changed: +34, -5 lines

vllm/compilation/collective_fusion.py (9 additions, 2 deletions)

@@ -431,8 +431,15 @@ def __init__(self, config: VllmConfig):
 
         self.dump_patterns(config, self.patterns)
 
-    def is_applicable_for_shape(self, shape: int | None) -> bool:
-        # only do replace for specific shapes
+    def is_applicable(self, shape: int | None) -> bool:
+        # This pass is applied on top of the sequence parallelism pass.
+        # It inherits the same applicability condition as `SequenceParallelismPass`.
+        # See `SequenceParallelismPass.is_applicable` for more details.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
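
Taken together, the new guard reduces to a small predicate: always applicable under full-graph compilation (no splitting ops) or inductor graph partitioning, otherwise only for a concrete shape divisible by the tensor-parallel size. Below is a minimal stand-alone sketch of that logic; `CompilationConfigSketch` and `pass_is_applicable` are illustrative stand-ins rather than vLLM APIs, and only the two field names are taken from the diff.

from dataclasses import dataclass, field


@dataclass
class CompilationConfigSketch:
    # Illustrative stand-in for the two config fields the new guard reads.
    splitting_ops: list[str] = field(default_factory=list)
    use_inductor_graph_partition: bool = False


def pass_is_applicable(
    cfg: CompilationConfigSketch, shape: int | None, tp_size: int
) -> bool:
    # Full cuda graph compilation (no splitting ops) or inductor graph
    # partitioning: the pass applies regardless of the runtime shape.
    if not cfg.splitting_ops or cfg.use_inductor_graph_partition:
        return True
    # Piecewise compilation: only apply to concrete shapes divisible by TP size.
    return shape is not None and shape % tp_size == 0


full_graph = CompilationConfigSketch(splitting_ops=[])
piecewise = CompilationConfigSketch(splitting_ops=["attention.op.placeholder"])

print(pass_is_applicable(full_graph, None, tp_size=4))   # True: no compile size needed
print(pass_is_applicable(piecewise, None, tp_size=4))    # False: shape is symbolic
print(pass_is_applicable(piecewise, 8192, tp_size=4))    # True: 8192 % 4 == 0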

vllm/compilation/inductor_pass.py (1 addition, 1 deletion)

@@ -96,7 +96,7 @@ def hash_dict(dict_: dict[Any, Any]):
         encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
         return hashlib.sha256(encoded).hexdigest()
 
-    def is_applicable_for_shape(self, shape: int | None):
+    def is_applicable(self, shape: int | None):
         return True
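
This is the base-class hook the pass manager consults; its default makes every pass run for every shape unless a subclass narrows the condition. A minimal sketch of the override pattern, using hypothetical classes rather than the actual vLLM `InductorPass` hierarchy:

class SketchPass:
    """Base behaviour matching the default above: applicable to any shape."""

    def is_applicable(self, shape: int | None) -> bool:
        return True


class EvenShapeOnlyPass(SketchPass):
    """Hypothetical override: only run for concrete shapes divisible by 2."""

    def is_applicable(self, shape: int | None) -> bool:
        return shape is not None and shape % 2 == 0


print(SketchPass().is_applicable(None))          # True
print(EvenShapeOnlyPass().is_applicable(None))   # False
print(EvenShapeOnlyPass().is_applicable(10))     # True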

vllm/compilation/pass_manager.py (3 additions, 1 deletion)

@@ -71,9 +71,11 @@ def __call__(self, graph: fx.Graph):
 
         shape = get_pass_context().runtime_shape
         for pass_ in self.passes:
-            if pass_.is_applicable_for_shape(shape):
+            if pass_.is_applicable(shape):
                 pass_(graph)
                 VllmInductorPass.dump_prefix += 1
+            else:
+                logger.debug("Skipping %s with shape %s", pass_, shape)
 
         # post-cleanup goes before fix_functionalization
         # because it requires a functional graph
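
The manager-side change is a guarded dispatch loop that now also logs which passes were skipped for the current runtime shape. A self-contained sketch of that pattern under assumed names (`run_passes` and `EvenShapeOnlyPass` are illustrative, not vLLM symbols):

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class EvenShapeOnlyPass:
    """Toy pass that is only applicable to concrete, even shapes."""

    def is_applicable(self, shape: int | None) -> bool:
        return shape is not None and shape % 2 == 0

    def __call__(self, graph: list[str]) -> None:
        graph.append("transformed")


def run_passes(passes, graph: list[str], runtime_shape: int | None) -> None:
    # Mirrors the loop in the diff: run applicable passes, debug-log the rest.
    for pass_ in passes:
        if pass_.is_applicable(runtime_shape):
            pass_(graph)
        else:
            logger.debug("Skipping %s with shape %s", pass_, runtime_shape)


graph: list[str] = []
run_passes([EvenShapeOnlyPass()], graph, runtime_shape=None)  # skipped, logged
run_passes([EvenShapeOnlyPass()], graph, runtime_shape=8)     # runs
print(graph)                                                  # ['transformed']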

vllm/compilation/sequence_parallelism.py (19 additions, 1 deletion)

@@ -482,7 +482,25 @@ def __init__(self, config: VllmConfig):
         ).register(self.patterns)
         self.dump_patterns(config, self.patterns)
 
-    def is_applicable_for_shape(self, shape: int | None) -> bool:
+    def is_applicable(self, shape: int | None) -> bool:
+        # When sequence parallelism is enabled, the residual tensor from RMSNorm
+        # needs to be split along the sequence dimension. However, this dimension
+        # is symbolic during piecewise compilation, and splitting symbolic shapes
+        # is not supported.
+        #
+        # This pass is therefore only applied when the sequence dimension is
+        # concrete:
+        # 1. In full-graph compilation mode (no Dynamo splitting ops are used).
+        #    For this case we always pad num_tokens to be a multiple of
+        #    tensor_parallel_size, so there's no need to check shape % tp_size == 0.
+        # 2. For specific shapes provided during compilation (e.g., from
+        #    `compile_sizes`), which must be divisible by the tensor-parallel
+        #    size.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
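
A worked example of the divisibility point in the comment: in full-graph mode the token count can be padded up to a multiple of the tensor-parallel size so the sequence dimension splits evenly, whereas an arbitrary `compile_sizes` entry must already satisfy `shape % tp_size == 0`. The helper below is a hypothetical sketch of that arithmetic, not vLLM's padding code:

import math

import torch


def pad_to_multiple(num_tokens: int, tp_size: int) -> int:
    # Hypothetical helper: round the token count up so the sequence
    # dimension divides evenly across tensor-parallel ranks.
    return math.ceil(num_tokens / tp_size) * tp_size


tp_size = 4
num_tokens = 1023                                 # not divisible by tp_size
padded = pad_to_multiple(num_tokens, tp_size)     # 1024

residual = torch.zeros(padded, 4096)              # [num_tokens, hidden_size]
chunks = torch.chunk(residual, tp_size, dim=0)    # split along the sequence dim
assert all(chunk.shape[0] == padded // tp_size for chunk in chunks)  # 256 each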

vllm/compilation/vllm_inductor_pass.py (2 additions, 0 deletions)

@@ -3,6 +3,7 @@
 import functools
 import operator
 import time
+import weakref
 from typing import ClassVar
 
 import regex as re
@@ -28,6 +29,7 @@ class VllmInductorPass(InductorPass):
     """Keep track of pass index for debug dump ordering."""
 
     def __init__(self, config: VllmConfig):
+        self.compilation_config = weakref.proxy(config.compilation_config)
         self.pass_config = config.compilation_config.pass_config
         self.model_dtype = config.model_config.dtype if config.model_config else None
         self.device = config.device_config.device if config.device_config else None
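
The `weakref.proxy` line keeps a handle to the compilation config on every pass without holding a strong reference to it: the proxy forwards attribute access but does not keep the config alive. A minimal stand-alone illustration of that behaviour; `Config` and `Pass` here are hypothetical, not vLLM classes:

import weakref


class Config:
    def __init__(self) -> None:
        self.splitting_ops: list[str] = []


class Pass:
    def __init__(self, config: Config) -> None:
        # Attribute access goes through to the real object, but the proxy
        # does not extend the config's lifetime.
        self.compilation_config = weakref.proxy(config)


cfg = Config()
p = Pass(cfg)
print(p.compilation_config.splitting_ops)  # [] -- forwarded to the live Config

del cfg                                     # drop the only strong reference
try:
    p.compilation_config.splitting_ops
except ReferenceError:
    print("config has been garbage collected")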
