Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions vllm_gaudi/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
requires the device's allocator to be a ``c10::DeviceAllocator``. We
replace it with an HPU-safe variant that uses
``current_platform.empty_cache()`` instead (see GAUDISW-247825).

* ``vllm.v1.sample.ops.logprobs.batched_count_greater_than`` — upstream
decorates this function with ``@torch.compile(dynamic=True, ...)``.
Habana's ``recipe_compiler`` backend cannot handle the symbolic shapes
produced by ``dynamic=True`` (and by ``mark_unbacked`` in the caller),
raising ``TypeError: Cannot convert symbols to int``. We replace it
with a plain (uncompiled) version of the same function. The replacement
is deferred to ``load_general_plugins`` time to avoid importing
``vllm.v1.sample.sampler`` during early plugin registration, which would
trigger a heavy import chain that interferes with platform initialisation.
"""

import functools
Expand Down Expand Up @@ -74,6 +84,31 @@ def _hpu_cleanup_dist_env_and_memory(shutdown_ray: bool = False) -> None:
parallel_state.logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5")


def _hpu_batched_count_greater_than(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
"""HPU-safe replacement for ``batched_count_greater_than``.

Identical logic to the upstream implementation but *not* wrapped in
``torch.compile``. The upstream decorator uses ``dynamic=True`` whose
symbolic shapes are incompatible with Habana's ``recipe_compiler``
backend, and ``mark_unbacked`` in the caller prevents ``dynamic=False``
from helping.
"""
return (x >= values).sum(-1)


def _patch_batched_count_greater_than() -> None:
"""Replace ``batched_count_greater_than`` in the sampler & logprobs modules.

Called from the ``load_general_plugins`` hook so that the heavy
``vllm.v1.sample.*`` import chain runs *after* platform initialisation.
"""
import vllm.v1.sample.ops.logprobs as _logprobs_mod
import vllm.v1.sample.sampler as _sampler_mod
Comment on lines +105 to +106

_logprobs_mod.batched_count_greater_than = _hpu_batched_count_greater_than
_sampler_mod.batched_count_greater_than = _hpu_batched_count_greater_than


def apply() -> None:
"""Install all HPU runtime monkey-patches."""
# --- torch.accelerator.empty_cache ---
Expand All @@ -83,15 +118,28 @@ def apply() -> None:
if not hasattr(torch._C, "_host_emptyCache"):
torch._C._host_emptyCache = lambda: None

# Patch the canonical definition.
# --- cleanup_dist_env_and_memory ---
parallel_state.cleanup_dist_env_and_memory = _hpu_cleanup_dist_env_and_memory
# Patch the re-export from ``vllm.distributed`` so ``from vllm.distributed
# import cleanup_dist_env_and_memory`` (used by the upstream pytest
# conftest) also picks up the HPU-safe version.
import vllm.distributed as _vllm_distributed

_vllm_distributed.cleanup_dist_env_and_memory = _hpu_cleanup_dist_env_and_memory

# --- batched_count_greater_than (deferred) ---
# We cannot import the sampler modules here because the import chain
# triggers platform re-initialisation ("Device string must not be
# empty"). Instead we hook into ``load_general_plugins`` which runs
# in every process (parent + EngineCore subprocess) after the platform
# is ready.
import vllm.plugins as _plugins_mod

_original_load_general = _plugins_mod.load_general_plugins

def _load_general_with_hpu_patches():
_original_load_general()
_patch_batched_count_greater_than()
Comment on lines +137 to +139

_plugins_mod.load_general_plugins = _load_general_with_hpu_patches


def patch_hf3fs_mock_client():
"""Guard CUDA sync in the HF3FS mock client on non-CUDA platforms.
Expand Down
Loading