Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/compile/test_async_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,15 @@ def async_tp_pass_on_test_model(
async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass)

assert (
async_tp_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
async_tp_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)

model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor

hidden_states = torch.randn(
Expand Down
17 changes: 17 additions & 0 deletions tests/compile/test_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy

import pytest

from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
Expand All @@ -25,6 +28,20 @@ def test_use_cudagraphs_dynamic():
assert vllm_config.compilation_config.use_cudagraph


def test_copy_pass():
vllm_config = VllmConfig()
inductor_pass = FixFunctionalizationPass(vllm_config)
copied_inductor_pass = copy.deepcopy(inductor_pass)
assert (
copied_inductor_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
assert (
copied_inductor_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)


def test_custom_op():
# proper syntax
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
Expand Down
8 changes: 8 additions & 0 deletions tests/compile/test_sequence_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ def sequence_parallelism_pass_on_test_model(

noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

Expand Down
15 changes: 13 additions & 2 deletions vllm/compilation/vllm_inductor_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
import operator
import time
import weakref
from dataclasses import dataclass
from typing import ClassVar

import regex as re
Expand All @@ -19,6 +19,12 @@
logger = init_logger(__name__)


@dataclass
class InductorCompilationConfig:
splitting_ops: list[str] | None = None
use_inductor_graph_partition: bool = False


class VllmInductorPass(InductorPass):
"""
An inductor pass with access to vLLM PassConfig.
Expand All @@ -29,7 +35,12 @@ class VllmInductorPass(InductorPass):
"""Keep track of pass index for debug dump ordering."""

def __init__(self, config: VllmConfig):
self.compilation_config = weakref.proxy(config.compilation_config)
# Get only the necessary CompilationConfig for the inductor pass, since
# full `CompilationConfig` contains pointer to model which is unsafe.
self.compilation_config = InductorCompilationConfig(
splitting_ops=config.compilation_config.splitting_ops,
use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition,
)
self.pass_config = config.compilation_config.pass_config
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

later pass_config can be also moved.

Copy link
Collaborator Author

@luccafong luccafong Oct 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can utilize it, but it will introduce duplicated attribute in config level, we can think of how to organize these config better in following PR. @zou3519

self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None
Expand Down