Skip to content

Commit

Permalink
[2/N][torch.compile] make compilation cfg part of vllm cfg (vllm-proj…
Browse files Browse the repository at this point in the history
…ect#10383)

Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored and coolkp committed Nov 20, 2024
1 parent 9c9a222 commit 4575dec
Show file tree
Hide file tree
Showing 27 changed files with 359 additions and 283 deletions.
8 changes: 5 additions & 3 deletions tests/compile/piecewise/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config
from vllm.utils import direct_register_custom_op

global_counter = 0
Expand Down Expand Up @@ -82,7 +82,9 @@ def test_simple_piecewise_compile():
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)

model = SillyModel(vllm_config=VllmConfig(), prefix='')
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')

inputs = torch.randn(100).cuda()

Expand Down
22 changes: 12 additions & 10 deletions tests/compile/piecewise/test_toy_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
from torch.library import Library

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.plugins import set_compilation_config
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_compilation_config, set_current_vllm_config
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
Expand Down Expand Up @@ -272,9 +270,11 @@ def run_model(llama_config,
CompilationLevel.NO_COMPILATION)
set_compilation_config(None)

model = LlamaModel(config=llama_config,
vllm_config=VllmConfig(),
prefix="").eval().cuda()
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda()

B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
Expand Down Expand Up @@ -395,9 +395,11 @@ def benchmark():
else:
set_compilation_config(None)

model = LlamaModel(config=llama_config,
vllm_config=VllmConfig(),
prefix="").eval().cuda().to(torch.bfloat16)
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda().to(torch.bfloat16)

B = 256 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
Expand Down
2 changes: 1 addition & 1 deletion tests/compile/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel
from vllm.utils import cuda_device_count_stateless

from ..utils import compare_all_settings
Expand Down
2 changes: 1 addition & 1 deletion tests/compile/test_full_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel

from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS, check_full_graph_support
Expand Down
2 changes: 1 addition & 1 deletion tests/compile/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from compressed_tensors.quantization import FP8_DTYPE

import vllm.envs as envs
from vllm.compilation.config import CompilationConfig
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
find_auto_fn_maybe)
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)
Expand Down
4 changes: 3 additions & 1 deletion tests/compile/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch

from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel


class MyMod(torch.nn.Module):
Expand All @@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
def __init__(self, model):
self.model = model
compiled_callable = torch.compile(self.forward, backend="eager")
super().__init__(compiled_callable)
super().__init__(compiled_callable,
compilation_level=CompilationLevel.DYNAMO_ONCE)

def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# this is the function to be compiled
Expand Down
2 changes: 1 addition & 1 deletion tests/compile/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel
from vllm.platforms import current_platform

TEST_MODELS = [
Expand Down
52 changes: 26 additions & 26 deletions tests/model_executor/test_enabled_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

import pytest

from vllm.config import CompilationConfig, VllmConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config


# Registered subclass for test
Expand Down Expand Up @@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation):
])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
default_on: bool):
os.environ["VLLM_CUSTOM_OPS"] = env
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on

# Reset default_on (computed once):
CustomOp.default_on.cache_clear()
ops_enabled = [bool(x) for x in ops_enabled]

assert CustomOp.default_on() == default_on
assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]

ops_enabled = [bool(x) for x in ops_enabled]
assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]

assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]

assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass

# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass

# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()
# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()


@pytest.mark.parametrize(
"env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str):
os.environ["VLLM_CUSTOM_OPS"] = env
CustomOp.default_on.cache_clear()

with pytest.raises(AssertionError):
RMSNorm(1024).enabled()
with pytest.raises(Exception): # noqa
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
RMSNorm(1024).enabled()
2 changes: 1 addition & 1 deletion tests/tpu/test_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import depyf

from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel

# disable custom dispatcher, let Dynamo takes over
# all the control
Expand Down
2 changes: 1 addition & 1 deletion tests/tpu/test_custom_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel

from ..utils import compare_two_settings

Expand Down
20 changes: 13 additions & 7 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@
import torch.fx as fx

import vllm.envs as envs
from vllm.config import CompilationConfig, CompilationLevel
from vllm.logger import init_logger
from vllm.utils import combine_fx_passes, weak_ref_tensors

from .config import CompilationConfig
from .counter import compilation_counter
from .fusion import FusionPass
from .levels import CompilationLevel
from .reshapes import RedundantReshapesPass

logger = init_logger(__name__)
Expand Down Expand Up @@ -392,7 +391,10 @@ class VllmBackend:
sym_tensor_indices: List[int]
input_buffers: List[torch.Tensor]

def __init__(self, post_grad_passes: Sequence[Callable] = ()):
def __init__(
self,
compilation_configs: CompilationConfig,
):
global global_graph_pool
if global_graph_pool is None:
global_graph_pool = torch.cuda.graph_pool_handle()
Expand All @@ -401,11 +403,13 @@ def __init__(self, post_grad_passes: Sequence[Callable] = ()):
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = global_graph_pool
self.post_grad_passes = post_grad_passes
self.post_grad_passes = []

self.sym_tensor_indices = []
self.input_buffers = []

self.compilation_configs = compilation_configs

# `torch.compile` is JIT compiled, so we don't need to
# do anything here

Expand Down Expand Up @@ -437,10 +441,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
assert not self._called, "VllmBackend can only be called once"

self.graph = graph
# config is read now, because only here can
# config is updated now, because only here can
# we get the sizes to capture for cudagraph
# from compilation context
self.compilation_configs = CompilationConfig.select_and_init_config()
self.compilation_configs.init_during_runtime()
self.add_passes_to_config()

self.split_gm, self.piecewise_graphs = split_graph(
Expand Down Expand Up @@ -688,4 +692,6 @@ def select_default_backend(level: int) -> Union[str, Callable]:
return backend_str
assert level == CompilationLevel.PIECEWISE

return VllmBackend()
from vllm.plugins import get_current_vllm_config
compilation_config = get_current_vllm_config().compilation_config
return VllmBackend(compilation_config)
Loading

0 comments on commit 4575dec

Please sign in to comment.