Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/compile/test_async_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,15 @@ def async_tp_pass_on_test_model(
async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass)

assert (
async_tp_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
async_tp_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)

model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor

hidden_states = torch.randn(
Expand Down
17 changes: 17 additions & 0 deletions tests/compile/test_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy

import pytest

from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
Expand All @@ -25,6 +28,20 @@ def test_use_cudagraphs_dynamic():
assert vllm_config.compilation_config.use_cudagraph


def test_copy_pass():
    """Check that a deep-copied inductor pass keeps its compilation settings.

    Builds a ``FixFunctionalizationPass`` from a default ``VllmConfig``,
    deep-copies it, and compares the copy's ``compilation_config`` attributes
    against the ones on the source config.
    """
    vllm_config = VllmConfig()
    original_pass = FixFunctionalizationPass(vllm_config)
    cloned_pass = copy.deepcopy(original_pass)

    cloned_settings = cloned_pass.compilation_config
    source_settings = vllm_config.compilation_config

    # Same comparison order as before: partition flag first, then splitting ops.
    for attr_name in ("use_inductor_graph_partition", "splitting_ops"):
        assert getattr(cloned_settings, attr_name) == getattr(
            source_settings, attr_name
        )


def test_custom_op():
# proper syntax
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
Expand Down
8 changes: 8 additions & 0 deletions tests/compile/test_sequence_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ def sequence_parallelism_pass_on_test_model(

noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

Expand Down
17 changes: 15 additions & 2 deletions vllm/compilation/vllm_inductor_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
import operator
import time
import weakref
from dataclasses import dataclass
from typing import ClassVar

import regex as re
Expand All @@ -19,6 +19,13 @@
logger = init_logger(__name__)


@dataclass
class InductorCompilationConfig:
splitting_ops: list[str] | None = None
use_inductor_graph_partition: bool = False
compile_sizes: list[int | str] | None = None


class VllmInductorPass(InductorPass):
"""
An inductor pass with access to vLLM PassConfig.
Expand All @@ -29,7 +36,13 @@ class VllmInductorPass(InductorPass):
"""Keep track of pass index for debug dump ordering."""

def __init__(self, config: VllmConfig):
self.compilation_config = weakref.proxy(config.compilation_config)
# Keep only the CompilationConfig fields the inductor pass needs, since the
# full `CompilationConfig` contains a pointer to the model, which is unsafe.
self.compilation_config = InductorCompilationConfig(
splitting_ops=config.splitting_ops,
use_inductor_graph_partition=config.use_inductor_graph_partition,
compile_sizes=config.compile_sizes,
)
self.pass_config = config.compilation_config.pass_config
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Later, pass_config can also be moved.

Copy link
Collaborator Author

@luccafong luccafong Oct 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can utilize it, but it would introduce duplicated attributes at the config level; we can think about how to organize these configs better in a follow-up PR. @zou3519

self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None
Expand Down