diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index c6f6072bdfc4..ce85bae5389d 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -10,7 +10,6 @@
 from typing import Any, ParamSpec, TypeVar
 
 import torch
-import torch._C._dynamo.guards
 
 import vllm.envs as envs
 from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
@@ -24,65 +23,23 @@
 P = ParamSpec("P")
 
 
-def _noop_add_global_state_guard(
-    self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
-) -> None:
-    """No-op to skip the GLOBAL_STATE guard entirely"""
-    pass
-
-
-def _noop_add_torch_function_mode_stack_guard(
-    self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
-) -> None:
-    """No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
-    pass
-
-
 @contextmanager
 def _compilation_context() -> Generator[None, None, None]:
-    """Context manager for compilation settings and patches.
-
-    This manager:
-    1. Sets higher dynamo cache limits for compilation. (Needed for
-       qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
-       Generally a recompilation can happen whenever we use a new
-       backend instance in torch.compile.
-    2. Patches out add_global_state_guard to skip GLOBAL_STATE guards
-    3. Patches out add_torch_function_mode_stack_guard to skip
-       TORCH_FUNCTION_MODE_STACK guards.
-    4. Restores everything when compilation completes
+    """Context manager for compilation settings.
+
+    This manager sets higher dynamo cache limits for compilation.
+    (Needed for qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
+    Generally a recompilation can happen whenever we use a new
+    backend instance in torch.compile.
     """
-    # Save original values
-    original_global_state_guard = (
-        torch._C._dynamo.guards.GuardManager.add_global_state_guard
-    )
-    original_torch_function_mode_stack_guard = (
-        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard
-    )
     original_cache_size = torch._dynamo.config.cache_size_limit
     original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit
 
     try:
-        # Set higher cache limits for compilation
         torch._dynamo.config.cache_size_limit = 2048
         torch._dynamo.config.accumulated_cache_size_limit = 8192
-
-        # Patch guard manager
-        torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
-            _noop_add_global_state_guard
-        )
-        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
-            _noop_add_torch_function_mode_stack_guard
-        )
         yield
     finally:
-        # Restore original values
-        torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
-            original_global_state_guard
-        )
-        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
-            original_torch_function_mode_stack_guard
-        )
         torch._dynamo.config.cache_size_limit = original_cache_size
         torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache
 
@@ -155,7 +112,7 @@ def __init__(self) -> None:
                     entry.guard_type == "SHAPE_ENV" for entry in x
                 ]
             else:
-                options["guard_filter_fn"] = lambda x: [False for _ in x]
+                options["guard_filter_fn"] = torch.compiler.skip_all_guards_unsafe
 
             compiled_ptr: Any = self.forward
             # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False