
Commit f178780

patch 21031
1 parent 3cd3666 commit f178780

5 files changed, 29 additions, 7 deletions

vllm/compilation/collective_fusion.py

Lines changed: 6 additions & 2 deletions
@@ -431,8 +431,12 @@ def __init__(self, config: VllmConfig):
 
         self.dump_patterns(config, self.patterns)
 
-    def is_applicable_for_shape(self, shape: int | None) -> bool:
-        # only do replace for specific shapes
+    def is_applicable(self, shape: int | None) -> bool:
+        # This pass is applied on top of the sequence parallelism pass.
+        # It inherits the same applicability condition as `SequenceParallelismPass`.
+        # See `SequenceParallelismPass.is_applicable` for more details.
+        if self.splitting_ops is None or self.splitting_ops == []:
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
 
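
The applicability rule that this pass now shares with the sequence-parallelism pass can be sketched as a standalone function. This is plain Python for illustration only; the function name and arguments below are made up for the sketch and are not vLLM code:

def is_applicable(shape: int | None, splitting_ops: list[str] | None, tp_size: int) -> bool:
    # Full-graph compilation (no splitting ops): always applicable, because
    # num_tokens is padded to a multiple of the tensor-parallel size.
    if not splitting_ops:
        return True
    # Piecewise compilation: only apply for concrete shapes that are
    # divisible by the tensor-parallel size.
    return shape is not None and shape % tp_size == 0


if __name__ == "__main__":
    assert is_applicable(None, [], tp_size=4)            # full-graph mode
    assert not is_applicable(None, ["attn"], tp_size=4)  # symbolic shape, piecewise
    assert is_applicable(256, ["attn"], tp_size=4)       # concrete and divisible
    assert not is_applicable(255, ["attn"], tp_size=4)   # concrete, not divisible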

vllm/compilation/inductor_pass.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def hash_dict(dict_: dict[Any, Any]):
         encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
         return hashlib.sha256(encoded).hexdigest()
 
-    def is_applicable_for_shape(self, shape: int | None):
+    def is_applicable(self, shape: int | None):
         return True
 
 
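The base class keeps a permissive default, so passes that do not depend on the runtime shape run unconditionally. A minimal sketch of the override pattern, using stand-in classes rather than vLLM's actual pass hierarchy:

class InductorPassSketch:
    def is_applicable(self, shape: int | None) -> bool:
        # Default: applicable for every runtime shape, including the
        # symbolic-shape case where shape is None.
        return True


class ShapeDependentPassSketch(InductorPassSketch):
    def __init__(self, tp_size: int):
        self.tp_size = tp_size

    def is_applicable(self, shape: int | None) -> bool:
        # Only applicable for concrete shapes divisible by tp_size.
        return shape is not None and shape % self.tp_size == 0


print(InductorPassSketch().is_applicable(None))          # True
print(ShapeDependentPassSketch(4).is_applicable(None))   # False
print(ShapeDependentPassSketch(4).is_applicable(128))    # True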
vllm/compilation/pass_manager.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def __call__(self, graph: fx.Graph):
 
         shape = get_pass_context().runtime_shape
         for pass_ in self.passes:
-            if pass_.is_applicable_for_shape(shape):
+            if pass_.is_applicable(shape):
                 pass_(graph)
                 VllmInductorPass.dump_prefix += 1
 
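
A standalone sketch of the loop above: each registered pass is asked is_applicable(shape) for the current runtime shape before it is run on the graph. The graph and pass types here are stand-ins, not vLLM's fx-based ones:

from typing import Callable


class DummyGraph:
    def __init__(self) -> None:
        self.applied: list[str] = []


class DummyPass:
    def __init__(self, name: str, predicate: Callable[[int | None], bool]) -> None:
        self.name = name
        self.predicate = predicate

    def is_applicable(self, shape: int | None) -> bool:
        return self.predicate(shape)

    def __call__(self, graph: DummyGraph) -> None:
        graph.applied.append(self.name)


def run_passes(passes: list[DummyPass], graph: DummyGraph, shape: int | None) -> None:
    # Mirrors the manager's loop above: consult each pass, run it if applicable.
    for pass_ in passes:
        if pass_.is_applicable(shape):
            pass_(graph)


graph = DummyGraph()
passes = [
    DummyPass("always_on", lambda s: True),
    DummyPass("tp_divisible_only", lambda s: s is not None and s % 4 == 0),
]
run_passes(passes, graph, shape=128)   # both passes apply
run_passes(passes, graph, shape=None)  # symbolic shape: only the first applies
print(graph.applied)  # ['always_on', 'tp_divisible_only', 'always_on']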

vllm/compilation/sequence_parallelism.py

Lines changed: 16 additions & 1 deletion
@@ -482,7 +482,22 @@ def __init__(self, config: VllmConfig):
         ).register(self.patterns)
         self.dump_patterns(config, self.patterns)
 
-    def is_applicable_for_shape(self, shape: int | None) -> bool:
+    def is_applicable(self, shape: int | None) -> bool:
+        # When sequence parallelism is enabled, the residual tensor from RMSNorm
+        # needs to be split along the sequence dimension. However, this dimension
+        # is symbolic during piecewise compilation, and splitting symbolic shapes
+        # is not supported.
+        #
+        # This pass is therefore only applied when the sequence dimension is
+        # concrete:
+        # 1. In full-graph compilation mode (no splitting ops are used).
+        #    In this case num_tokens is always padded to a multiple of
+        #    tensor_parallel_size, so there is no need to check shape % tp_size == 0.
+        # 2. For a specific shape provided during compilation (e.g., from
+        #    `compile_sizes`), which must be divisible by the tensor-parallel
+        #    size.
+        if self.splitting_ops is None or self.splitting_ops == []:
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
 
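
A small illustration (not vLLM code) of why case 1 in the comment is safe: if the token count is always padded up to a multiple of the tensor-parallel size, the residual tensor splits evenly along the sequence dimension without needing the shape % tp_size check:

def pad_num_tokens(num_tokens: int, tp_size: int) -> int:
    # Round up to the next multiple of tp_size.
    return (num_tokens + tp_size - 1) // tp_size * tp_size


tp_size = 4
for num_tokens in (1, 5, 16, 127):
    padded = pad_num_tokens(num_tokens, tp_size)
    print(f"{num_tokens} -> {padded} tokens = {tp_size} ranks x {padded // tp_size}")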

vllm/compilation/vllm_inductor_pass.py

Lines changed: 5 additions & 2 deletions
@@ -29,8 +29,11 @@ class VllmInductorPass(InductorPass):
 
     def __init__(self, config: VllmConfig):
         self.pass_config = config.compilation_config.pass_config
-        self.model_dtype = config.model_config.dtype if config.model_config else None
-        self.device = config.device_config.device if config.device_config else None
+        self.splitting_ops = config.compilation_config.splitting_ops
+        self.model_dtype = config.model_config.dtype if config.model_config \
+            else None
+        self.device = config.device_config.device if config.device_config \
+            else None
         self.pass_name = self.__class__.__name__
 
     @staticmethod
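
A minimal sketch of the pattern in __init__ above, with simple dataclasses standing in for vLLM's VllmConfig: splitting_ops is cached at construction so is_applicable can consult it later without re-reading the config. The class and field names below are stand-ins for illustration:

from dataclasses import dataclass, field


@dataclass
class CompilationConfigSketch:
    splitting_ops: list[str] | None = None


@dataclass
class ConfigSketch:
    compilation_config: CompilationConfigSketch = field(
        default_factory=CompilationConfigSketch
    )


class PassSketch:
    def __init__(self, config: ConfigSketch) -> None:
        # Cache the field once; is_applicable reads the cached value later.
        self.splitting_ops = config.compilation_config.splitting_ops

    def is_applicable(self, shape: int | None) -> bool:
        if self.splitting_ops is None or self.splitting_ops == []:
            return True  # full-graph compilation: always applicable
        return shape is not None  # piecewise: need a concrete shape


print(PassSketch(ConfigSketch()).is_applicable(None))  # True
print(PassSketch(ConfigSketch(CompilationConfigSketch(["attn"]))).is_applicable(None))  # False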
