Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 7 additions & 50 deletions vllm/compilation/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Any, ParamSpec, TypeVar

import torch
import torch._C._dynamo.guards

import vllm.envs as envs
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
Expand All @@ -24,65 +23,23 @@
P = ParamSpec("P")


def _noop_add_global_state_guard(
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
"""No-op to skip the GLOBAL_STATE guard entirely"""
pass


def _noop_add_torch_function_mode_stack_guard(
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
pass


@contextmanager
def _compilation_context() -> Generator[None, None, None]:
    """Context manager for compilation settings and patches.

    While the context is active, this manager:

    1. Sets higher dynamo cache limits for compilation
       (``cache_size_limit = 2048`` and
       ``accumulated_cache_size_limit = 8192``). Needed for qwen2_5_vl,
       see test_qwen2_5_vl_evs_functionality. Generally a recompilation
       can happen whenever we use a new backend instance in torch.compile.
    2. Patches out ``GuardManager.add_global_state_guard`` to skip
       GLOBAL_STATE guards.
    3. Patches out ``GuardManager.add_torch_function_mode_stack_guard`` to
       skip TORCH_FUNCTION_MODE_STACK guards.
    4. Restores everything when compilation completes; the restore runs in
       a ``finally`` clause, so it also happens if the managed body raises.

    Yields:
        None. The context is entered purely for its side effects.
    """
    guard_manager = torch._C._dynamo.guards.GuardManager
    dynamo_config = torch._dynamo.config

    # Save original values so the ``finally`` clause can put them back.
    original_global_state_guard = guard_manager.add_global_state_guard
    original_torch_function_mode_stack_guard = (
        guard_manager.add_torch_function_mode_stack_guard
    )
    original_cache_size = dynamo_config.cache_size_limit
    original_accumulated_cache = dynamo_config.accumulated_cache_size_limit

    try:
        # Set higher cache limits for compilation
        dynamo_config.cache_size_limit = 2048
        dynamo_config.accumulated_cache_size_limit = 8192

        # Patch guard manager: both guard-adding methods become no-ops.
        guard_manager.add_global_state_guard = _noop_add_global_state_guard
        guard_manager.add_torch_function_mode_stack_guard = (
            _noop_add_torch_function_mode_stack_guard
        )
        yield
    finally:
        # Restore original values
        guard_manager.add_global_state_guard = original_global_state_guard
        guard_manager.add_torch_function_mode_stack_guard = (
            original_torch_function_mode_stack_guard
        )
        dynamo_config.cache_size_limit = original_cache_size
        dynamo_config.accumulated_cache_size_limit = original_accumulated_cache

Expand Down Expand Up @@ -155,7 +112,7 @@ def __init__(self) -> None:
entry.guard_type == "SHAPE_ENV" for entry in x
]
else:
options["guard_filter_fn"] = lambda x: [False for _ in x]
options["guard_filter_fn"] = torch.compiler.skip_all_guards_unsafe

compiled_ptr: Any = self.forward
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
Expand Down
Loading