diff --git a/tools/pre_commit/check_forbidden_imports.py b/tools/pre_commit/check_forbidden_imports.py index ac7d8b096ec4..365b2f5bb771 100644 --- a/tools/pre_commit/check_forbidden_imports.py +++ b/tools/pre_commit/check_forbidden_imports.py @@ -31,6 +31,7 @@ class ForbiddenImport: "vllm/transformers_utils/config.py", "vllm/model_executor/models/registry.py", "vllm/compilation/caching.py", + "vllm/env_override.py", "vllm/compilation/piecewise_backend.py", "vllm/distributed/utils.py", "vllm/distributed/parallel_state.py", diff --git a/vllm/env_override.py b/vllm/env_override.py index 5358568fc180..432300f4cf55 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -87,7 +87,7 @@ def _maybe_set_cuda_compatibility_path(): import torch from vllm.logger import init_logger -from vllm.utils.torch_utils import is_torch_equal +from vllm.utils.torch_utils import is_torch_equal, is_torch_equal_or_newer logger = init_logger(__name__) @@ -490,3 +490,45 @@ def _patch_get_raw_stream_if_needed(): PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched GraphLowering._update_scheduler = _update_scheduler_patched + +# =================================================== +# torch <2.12 GraphCaptureOutput.get_runtime_env monkeypatch +# =================================================== +# PyTorch's AOT compile path omits builtins from used_globals, causing +# 'Missing required external references' errors for refs like 'type'. +# (which happens in transformers code) +# This mirrors the fix in https://github.com/pytorch/pytorch/pull/177558 +# and can be removed once torch >=2.12 is the minimum supported version. + +if not is_torch_equal_or_newer("2.12.0"): + import builtins as _builtins + import pickle + + from torch._dynamo.convert_frame import GraphCaptureOutput + + _original_get_runtime_env = GraphCaptureOutput.get_runtime_env + + def _safe_builtins_dict(builtins_dict: dict) -> dict: + """Filter a builtins dict to only picklable entries for serialization.""" + result = {} + for k, v in builtins_dict.items(): + try: + pickle.dumps(v) + result[k] = v + except Exception: + pass + return result + + def _patched_get_runtime_env(self): # type: ignore[no-untyped-def] + runtime_env = _original_get_runtime_env(self) + for ref in runtime_env.external_refs: + if ref not in runtime_env.used_globals: + if ref.startswith("__builtins_dict__") and ref in self.f_globals: + runtime_env.used_globals[ref] = _safe_builtins_dict( + self.f_globals[ref] + ) + elif hasattr(_builtins, ref): + runtime_env.used_globals[ref] = getattr(_builtins, ref) + return runtime_env + + GraphCaptureOutput.get_runtime_env = _patched_get_runtime_env