diff --git a/setup.py b/setup.py index 5673a4a47a..bc766de4a7 100644 --- a/setup.py +++ b/setup.py @@ -62,6 +62,11 @@ def _read_requirements(filename: str) -> list[str]: install_requires=get_requirements(), ext_modules=ext_modules, extras_require={}, + data_files=[ + # Install a .pth file so the torch compat shim runs at Python startup, + # before ``import vllm`` triggers env_override.py. + (".", ["vllm_gaudi_torch_compat.pth"]), + ], entry_points={ "vllm.platform_plugins": ["hpu = vllm_gaudi:register"], "vllm.general_plugins": [ diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..17ef7add89 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Root-level conftest – ensures torch compatibility shims are applied +before any ``import vllm`` happens during the test session. +""" + +import vllm_gaudi._torch_compat # noqa: F401 -- side-effect: patches GraphCaptureOutput alias diff --git a/vllm_gaudi/_torch_compat.py b/vllm_gaudi/_torch_compat.py new file mode 100644 index 0000000000..cc4233f807 --- /dev/null +++ b/vllm_gaudi/_torch_compat.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Torch compatibility shim for Gaudi's custom PyTorch builds. + +Gaudi's PyTorch build (2.9+hpu) cherry-picked the builtins fix from +upstream PyTorch (pytorch/177558), which renamed ``GraphCaptureOutput`` +to ``CaptureOutput`` and removed the ``get_runtime_env`` method. + +vLLM's ``env_override.py`` (guarded by ``not is_torch_equal_or_newer("2.12.0")``) +tries to import ``GraphCaptureOutput`` and patch its ``get_runtime_env``. +On Gaudi's build this block must be skipped because the fix is already applied. + +We inject a stub ``GraphCaptureOutput`` class with a ``get_runtime_env`` +class-method so that ``env_override.py`` can import and "patch" it without +error. The patched method is never actually called because the underlying +PyTorch code already contains the fix. + +This module is loaded: +* In tests – via ``tests/conftest.py`` (runs before any ``import vllm``). +* At runtime – via a ``.pth`` file installed into site-packages so that + the shim is in place before *any* Python code imports ``vllm``. +""" + +try: + import torch._dynamo.convert_frame as _cf + + if not hasattr(_cf, "GraphCaptureOutput"): + # The Gaudi PyTorch build already has the builtins fix applied; + # create a stub so that env_override.py can import and monkey-patch + # it harmlessly. + + class _GraphCaptureOutputStub: + """Stub standing in for the removed GraphCaptureOutput class.""" + + def get_runtime_env(self): # type: ignore[override] + """No-op — the real fix is already in this PyTorch build.""" + return None + + _cf.GraphCaptureOutput = _GraphCaptureOutputStub # type: ignore[attr-defined] +except Exception: + # If torch._dynamo.convert_frame is unavailable, there is nothing + # to patch – silently continue. + pass diff --git a/vllm_gaudi/models/deepseek_ocr.py b/vllm_gaudi/models/deepseek_ocr.py index 13eecae107..1205167abd 100644 --- a/vllm_gaudi/models/deepseek_ocr.py +++ b/vllm_gaudi/models/deepseek_ocr.py @@ -10,7 +10,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict +from vllm.inputs import MultiModalDataDict from vllm.model_executor.models.deepseek_ocr import ( DeepseekOCRForCausalLM, DeepseekOCRMultiModalProcessor, diff --git a/vllm_gaudi_torch_compat.pth b/vllm_gaudi_torch_compat.pth new file mode 100644 index 0000000000..b03d1f248a --- /dev/null +++ b/vllm_gaudi_torch_compat.pth @@ -0,0 +1 @@ +import vllm_gaudi._torch_compat