diff --git a/vllm_gaudi/patches.py b/vllm_gaudi/patches.py index 4a784dd26d..ac3931eb73 100644 --- a/vllm_gaudi/patches.py +++ b/vllm_gaudi/patches.py @@ -17,6 +17,16 @@ requires the device's allocator to be a ``c10::DeviceAllocator``. We replace it with an HPU-safe variant that uses ``current_platform.empty_cache()`` instead (see GAUDISW-247825). + +* ``vllm.v1.sample.ops.logprobs.batched_count_greater_than`` — upstream + decorates this function with ``@torch.compile(dynamic=True, ...)``. + Habana's ``recipe_compiler`` backend cannot handle the symbolic shapes + produced by ``dynamic=True`` (and by ``mark_unbacked`` in the caller), + raising ``TypeError: Cannot convert symbols to int``. We replace it + with a plain (uncompiled) version of the same function. The replacement + is deferred to ``load_general_plugins`` time to avoid importing + ``vllm.v1.sample.sampler`` during early plugin registration, which would + trigger a heavy import chain that interferes with platform initialisation. """ import functools @@ -74,6 +84,31 @@ def _hpu_cleanup_dist_env_and_memory(shutdown_ray: bool = False) -> None: parallel_state.logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5") +def _hpu_batched_count_greater_than(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor: + """HPU-safe replacement for ``batched_count_greater_than``. + + Identical logic to the upstream implementation but *not* wrapped in + ``torch.compile``. The upstream decorator uses ``dynamic=True`` whose + symbolic shapes are incompatible with Habana's ``recipe_compiler`` + backend, and ``mark_unbacked`` in the caller prevents ``dynamic=False`` + from helping. + """ + return (x >= values).sum(-1) + + +def _patch_batched_count_greater_than() -> None: + """Replace ``batched_count_greater_than`` in the sampler & logprobs modules. + + Called from the ``load_general_plugins`` hook so that the heavy + ``vllm.v1.sample.*`` import chain runs *after* platform initialisation. + """ + import vllm.v1.sample.ops.logprobs as _logprobs_mod + import vllm.v1.sample.sampler as _sampler_mod + + _logprobs_mod.batched_count_greater_than = _hpu_batched_count_greater_than + _sampler_mod.batched_count_greater_than = _hpu_batched_count_greater_than + + def apply() -> None: """Install all HPU runtime monkey-patches.""" # --- torch.accelerator.empty_cache --- @@ -83,15 +118,28 @@ def apply() -> None: if not hasattr(torch._C, "_host_emptyCache"): torch._C._host_emptyCache = lambda: None - # Patch the canonical definition. + # --- cleanup_dist_env_and_memory --- parallel_state.cleanup_dist_env_and_memory = _hpu_cleanup_dist_env_and_memory - # Patch the re-export from ``vllm.distributed`` so ``from vllm.distributed - # import cleanup_dist_env_and_memory`` (used by the upstream pytest - # conftest) also picks up the HPU-safe version. import vllm.distributed as _vllm_distributed _vllm_distributed.cleanup_dist_env_and_memory = _hpu_cleanup_dist_env_and_memory + # --- batched_count_greater_than (deferred) --- + # We cannot import the sampler modules here because the import chain + # triggers platform re-initialisation ("Device string must not be + # empty"). Instead we hook into ``load_general_plugins`` which runs + # in every process (parent + EngineCore subprocess) after the platform + # is ready. + import vllm.plugins as _plugins_mod + + _original_load_general = _plugins_mod.load_general_plugins + + def _load_general_with_hpu_patches(): + _original_load_general() + _patch_batched_count_greater_than() + + _plugins_mod.load_general_plugins = _load_general_with_hpu_patches + def patch_hf3fs_mock_client(): """Guard CUDA sync in the HF3FS mock client on non-CUDA platforms.