[FIX_FOR_VLLM_CUSTOM=e31915063da3f6d6be6080040de28f5bb6945acd] Fix GraphCaptureOutput stub and MultiModalDataDict import path#1280

Closed
pawel-olejniczak wants to merge 1 commit into vllm-project:main from pawel-olejniczak:fix/vllm-hourly-graphcapture-multimodal

Conversation

@pawel-olejniczak
Contributor

Summary

Fixes two import errors introduced by recent upstream vLLM changes that break all CI tests on Gaudi HPU.

Bug 1: GraphCaptureOutput ImportError (blocks all tests)

Upstream vLLM PR #37234 (commit e31915063da) added a monkey-patch in env_override.py guarded by not is_torch_equal_or_newer("2.12.0") that imports GraphCaptureOutput from torch._dynamo.convert_frame and patches its get_runtime_env method.
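To make the failure mode concrete, here is a hedged, minimal reimplementation of the version guard described above. The real helper lives in vLLM; this sketch only illustrates why the guard fires on Gaudi's build string.

```python
# Hedged sketch of the upstream version guard; the real is_torch_equal_or_newer
# lives in vLLM, this minimal version only illustrates the Gaudi behaviour.
def is_torch_equal_or_newer(target: str, current: str = "2.9.0+hpu") -> bool:
    def parse(version: str) -> tuple:
        # Drop local build suffixes like "+hpu" before comparing numerically.
        return tuple(int(part) for part in version.split("+")[0].split("."))
    return parse(current) >= parse(target)

# Gaudi reports torch 2.9.0+hpu, so the "< 2.12.0" branch activates and
# env_override.py attempts the GraphCaptureOutput import that now fails.
assert not is_torch_equal_or_newer("2.12.0")
```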

Gaudi's PyTorch build (2.9.0+hpu) cherry-picked the upstream PyTorch fix (pytorch/177558) which:

  • Renamed GraphCaptureOutput to CaptureOutput
  • Removed the get_runtime_env method (the class is now empty)

Since Gaudi's torch reports version < 2.12.0, vLLM's guard activates and the import fails.

Fix: Add a _torch_compat.py shim that creates a stub GraphCaptureOutput class with a no-op get_runtime_env method. The stub satisfies env_override.py's import and monkey-patching without error. The patched method is never called at runtime because Gaudi's PyTorch already contains the underlying fix. The shim is loaded:

  • In tests: via tests/conftest.py (before any import vllm)
  • At runtime: via a .pth file installed into site-packages

Bug 2: MultiModalDataDict ImportError (affects deepseek_ocr)

Upstream vLLM PR #35182 (commit ba2f0acc2) moved MultiModalDataDict from vllm.multimodal.inputs to vllm.inputs.

Fix: Update the import path in vllm_gaudi/models/deepseek_ocr.py.
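An alternative to a hard path switch is a backwards-compatible import that tolerates either module layout. The vllm module paths below come from the PR description; the final fallback is only there so this sketch runs standalone without vllm installed.

```python
# Hedged sketch: tolerate the upstream module move in either direction.
try:
    from vllm.inputs import MultiModalDataDict  # new upstream location
except ImportError:
    try:
        from vllm.multimodal.inputs import MultiModalDataDict  # old location
    except ImportError:
        # Placeholder so the sketch runs without vllm; not part of the fix.
        MultiModalDataDict = dict
```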

Files Changed

File Change
vllm_gaudi/_torch_compat.py NEW — Torch compat shim with GraphCaptureOutput stub
tests/conftest.py NEW — Root conftest that loads the shim before vLLM imports
vllm_gaudi_torch_compat.pth NEW — .pth file for runtime shim loading
setup.py Added data_files for .pth installation
vllm_gaudi/models/deepseek_ocr.py Fixed MultiModalDataDict import path

HPU Verification

Tested on Gaudi3 pod (torch 2.9.0+hpu_1.23.0, Python 3.12):

  • import vllm succeeds (was crashing before)
  • pytest tests/unit_tests/ops/test_hpu_fused_moe.py — 1 passed
  • deepseek_ocr.py import reaches past the fixed line

Jira: Related to hourly CI triage findings.

…aphCaptureOutput alias and MultiModalDataDict import path

Bug 1: Gaudi's custom PyTorch build cherry-picked the rename of
GraphCaptureOutput -> CaptureOutput before bumping to 2.12.  Upstream
vLLM's env_override.py imports GraphCaptureOutput when torch < 2.12,
which fails on Gaudi.  Fix: a _torch_compat shim creates the alias,
loaded via conftest.py (tests) and a .pth file (production).

Bug 2: Upstream vLLM moved MultiModalDataDict from
vllm.multimodal.inputs to vllm.inputs.  Updated deepseek_ocr.py.

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
Copilot AI review requested due to automatic review settings March 31, 2026 10:14
@pawel-olejniczak pawel-olejniczak deleted the fix/vllm-hourly-graphcapture-multimodal branch March 31, 2026 10:14
Contributor

Copilot AI left a comment


Pull request overview

This PR addresses upstream vLLM compatibility regressions that currently break Gaudi CI by adding a Torch startup shim for GraphCaptureOutput and updating the MultiModalDataDict import path for the DeepSeek OCR model integration.

Changes:

  • Add a Torch compatibility shim to provide a stub GraphCaptureOutput for Gaudi’s custom PyTorch builds.
  • Ensure the shim is loaded in tests and (intended) at runtime via a .pth startup hook.
  • Update MultiModalDataDict import in deepseek_ocr.py to match the upstream vLLM module move.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

File Description
vllm_gaudi/models/deepseek_ocr.py Update MultiModalDataDict import to vllm.inputs.
vllm_gaudi/_torch_compat.py Add shim that injects torch._dynamo.convert_frame.GraphCaptureOutput when missing.
tests/conftest.py Import shim early in test startup.
vllm_gaudi_torch_compat.pth Add Python startup hook intended to load shim before vllm imports.
setup.py Install the .pth file via data_files.

Comment thread vllm_gaudi_torch_compat.pth

@@ -0,0 +1 @@
import vllm_gaudi._torch_compat

Copilot AI Mar 31, 2026


The .pth imports vllm_gaudi._torch_compat, which first executes vllm_gaudi/__init__.py. That module imports vllm_gaudi.platform, and platform.py imports vllm (from vllm import envs), so this startup hook can end up importing vllm before the shim runs—defeating the intended ordering and potentially re-triggering the original env_override.py failure.

Consider changing the .pth to import a standalone shim module that does not import the vllm_gaudi package (or refactor vllm_gaudi/__init__.py to avoid importing platform at import-time) so the patch can be applied without pulling in vllm.

Suggested change
import vllm_gaudi._torch_compat
import vllm_gaudi_torch_compat

Comment thread tests/conftest.py
Comment on lines +2 to +6
"""Root-level conftest – ensures torch compatibility shims are applied
before any ``import vllm`` happens during the test session.
"""

import vllm_gaudi._torch_compat # noqa: F401 -- side-effect: patches GraphCaptureOutput alias

Copilot AI Mar 31, 2026


This import executes vllm_gaudi/__init__.py before _torch_compat.py runs; __init__.py imports vllm_gaudi.platform which imports vllm (from vllm import envs). That means this conftest does not guarantee the shim is applied before any import vllm, contrary to the module docstring, and may fail in the same way the shim is trying to prevent.

To ensure ordering, import a shim that doesn't import the vllm_gaudi package (or make vllm_gaudi/__init__.py lazy) and then apply the patch from there.

Comment thread setup.py
Comment on lines +65 to +69
data_files=[
    # Install a .pth file so the torch compat shim runs at Python startup,
    # before ``import vllm`` triggers env_override.py.
    (".", ["vllm_gaudi_torch_compat.pth"]),
],

Copilot AI Mar 31, 2026


Installing the .pth via data_files=[('.', ...)] typically places it under the scheme's data directory (often sys.prefix), not the purelib/site-packages directory where Python actually processes .pth files at startup. If it doesn't land in site-packages, the runtime shim won't run and the original import vllm failure will persist.

Please adjust the installation approach so the .pth ends up in site-packages (or use an alternative startup mechanism that is reliably loaded from site-packages).

Suggested change
data_files=[
    # Install a .pth file so the torch compat shim runs at Python startup,
    # before ``import vllm`` triggers env_override.py.
    (".", ["vllm_gaudi_torch_compat.pth"]),
],
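The review point above hinges on where setuptools actually places data_files entries. A quick stdlib-only sketch of how to check where .pth files must land for this interpreter (the site module only processes .pth files found in purelib/site-packages directories):

```python
import sysconfig

# .pth files are processed by the stdlib site module at startup, but only
# when they sit in a site directory (purelib); sysconfig reports where
# that is for the running interpreter.
purelib = sysconfig.get_paths()["purelib"]
print(purelib)

# Lines in a .pth file that start with "import" are executed at startup,
# which is what makes the shim-loading hook work when placed correctly.
```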

        return None

    _cf.GraphCaptureOutput = _GraphCaptureOutputStub  # type: ignore[attr-defined]
except Exception:

Copilot AI Mar 31, 2026


This catches all exceptions and silently ignores them. Since this module is meant to protect imports, swallowing unexpected errors (e.g., API changes, AttributeError, RuntimeError during torch init) can make failures much harder to diagnose.

Prefer catching the specific expected failures (e.g., ImportError / ModuleNotFoundError) and, if you still want to proceed silently, consider at least logging at debug level for unexpected exception types.

Suggested change
except Exception:
except (ImportError, ModuleNotFoundError):

class _GraphCaptureOutputStub:
    """Stub standing in for the removed GraphCaptureOutput class."""

    def get_runtime_env(self):  # type: ignore[override]

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring says the stub provides a get_runtime_env class-method, but the stub defines it as an instance method. Even if it's "never called", matching the documented/expected callable shape is safer because the upstream monkey-patch may treat it as a classmethod.

Define get_runtime_env as a @classmethod (or otherwise match the upstream signature) to avoid subtle incompatibilities.

Suggested change
def get_runtime_env(self): # type: ignore[override]
@classmethod
def get_runtime_env(cls): # type: ignore[override]
