From adb7220a9bfd05edf72e588aa9b776d14352de83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Olejniczak?= Date: Thu, 9 Apr 2026 17:27:38 +0300 Subject: [PATCH 1/5] Fix HPU attention forward guard for removed use_output attribute MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Paweł Olejniczak --- vllm_gaudi/ops/hpu_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_gaudi/ops/hpu_attention.py b/vllm_gaudi/ops/hpu_attention.py index 5afcb2c25..fc7822df3 100644 --- a/vllm_gaudi/ops/hpu_attention.py +++ b/vllm_gaudi/ops/hpu_attention.py @@ -28,7 +28,7 @@ def patched_attention_forward( context using `vllm.forward_context.get_forward_context().attn_metadata`. """ - if self.use_output or not self.use_direct_call: + if not self.use_direct_call: return layer.Attention._vllm_gaudi_original_forward(self, query, key, value, output_shape=output_shape) if self.calculate_kv_scales: From c2ae5a6864bccb4e5cc065179d63322c89ba4be6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Olejniczak?= Date: Thu, 9 Apr 2026 17:27:56 +0300 Subject: [PATCH 2/5] Fix offloading connector tests for upstream API changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Paweł Olejniczak --- .../offloading_connector/test_scheduler.py | 15 +++++++-------- .../kv_offload/offloading_connector/utils.py | 10 ++++++---- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py b/tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py index 29d33ced3..374902609 100644 --- a/tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py +++ b/tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py @@ -8,8 +8,7 @@ generate_store_output, ) from tests.unit_tests.kv_offload.utils import EOS_TOKEN_ID from vllm.distributed.kv_events import BlockRemoved, BlockStored -from vllm.v1.core.kv_cache_utils import BlockHash -from vllm.v1.kv_offload.abstract import OffloadingEvent +from vllm.v1.kv_offload.abstract import OffloadingEvent, OffloadKey, make_offload_key from vllm.v1.request import RequestStatus @@ -108,19 +107,19 @@ def test_offloading_connector(request_runner, async_scheduling: bool): runner.run(decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5)) # test take_events - def to_hashes(int_hashes: list[int]) -> list[BlockHash]: - return [BlockHash(str(i).encode()) for i in int_hashes] + def to_keys(int_ids: list[int]) -> list[OffloadKey]: + return [make_offload_key(str(i).encode(), 0) for i in int_ids] def take_events() -> Iterable[OffloadingEvent]: - yield OffloadingEvent(block_hashes=to_hashes([1, 2, 3]), block_size=16, medium="A", removed=False) - yield OffloadingEvent(block_hashes=to_hashes([4, 5, 6]), block_size=32, medium="B", removed=True) + yield OffloadingEvent(keys=to_keys([1, 2, 3]), block_size=16, medium="A", removed=False) + yield OffloadingEvent(keys=to_keys([4, 5, 6]), block_size=32, medium="B", removed=True) runner.manager.take_events.side_effect = take_events events = list(runner.scheduler_connector.take_events()) assert len(events) == 2 event = events[0] assert isinstance(event, BlockStored) - assert event.block_hashes == to_hashes([1, 2, 3]) + assert event.block_hashes == [str(i).encode() for i in [1, 2, 3]] assert event.block_size == 16 assert event.medium == "A" assert event.token_ids == [] @@ -129,7 +128,7 @@ def take_events() -> Iterable[OffloadingEvent]: assert event.lora_name is None event = events[1] assert isinstance(event, BlockRemoved) - assert event.block_hashes == to_hashes([4, 5, 6]) + assert event.block_hashes == [str(i).encode() for i in [4, 5, 6]] assert event.medium == "B" diff --git a/tests/unit_tests/kv_offload/offloading_connector/utils.py b/tests/unit_tests/kv_offload/offloading_connector/utils.py index d96efec08..c9a9483f6 100644 --- a/tests/unit_tests/kv_offload/offloading_connector/utils.py +++ b/tests/unit_tests/kv_offload/offloading_connector/utils.py @@ -216,8 +216,10 @@ def __init__(self, assert isinstance(manager, MagicMock) self.manager: MagicMock = manager - assert connector_scheduler.gpu_block_size == gpu_block_size - assert connector_scheduler.offloaded_block_size == offloaded_block_size + assert len(connector_scheduler.config.kv_group_configs) == 1 + kv_group_config = connector_scheduler.config.kv_group_configs[0] + assert kv_group_config.gpu_block_size == gpu_block_size + assert kv_group_config.offloaded_block_size == offloaded_block_size # extract OffloadingSpec of worker_connector connector_worker = self.worker_connector.connector_worker @@ -459,7 +461,7 @@ def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks, async_s def generate_store_output(block_hashes: Iterable[BlockHash]): block_hashes = list(block_hashes) return PrepareStoreOutput( - block_hashes_to_store=list(block_hashes), + keys_to_store=list(block_hashes), store_spec=MockLoadStoreSpec(block_hashes), - block_hashes_evicted=[], + evicted_keys=[], ) From b4cc45aeaa30dec978e941efacb593e77b2e234d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Olejniczak?= Date: Thu, 9 Apr 2026 17:28:11 +0300 Subject: [PATCH 3/5] Fix FP8 block kernel registration and ops test config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Paweł Olejniczak --- tests/unit_tests/ops/conftest.py | 24 ++++++++++++++++++++++++ vllm_gaudi/ops/hpu_fp8.py | 27 ++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 tests/unit_tests/ops/conftest.py diff --git a/tests/unit_tests/ops/conftest.py b/tests/unit_tests/ops/conftest.py new file mode 100644 index 000000000..d6509ff1a --- /dev/null +++ b/tests/unit_tests/ops/conftest.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +import pytest +import torch + +from vllm.config import VllmConfig, set_current_vllm_config + + +@pytest.fixture +def default_vllm_config(): + """VllmConfig with a minimal model_config stub. + + Upstream Fp8LinearMethod.__init__ accesses model_config.dtype. + We provide a SimpleNamespace with the attributes required by quantization + methods so that ops-level unit tests can run without a full model setup. + """ + vllm_config = VllmConfig() + vllm_config.model_config = SimpleNamespace(dtype=torch.bfloat16, is_moe=False, hf_config=None, quantization=None) + + with set_current_vllm_config(vllm_config): + yield diff --git a/vllm_gaudi/ops/hpu_fp8.py b/vllm_gaudi/ops/hpu_fp8.py index 64bcd46e6..82031729d 100644 --- a/vllm_gaudi/ops/hpu_fp8.py +++ b/vllm_gaudi/ops/hpu_fp8.py @@ -16,8 +16,9 @@ from vllm_gaudi.ops.hpu_fused_moe import _normalize_moe_activation from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_hidden_states, dispatch_tensor, get_hpu_dp_metadata -from vllm.model_executor.kernels.linear import _POSSIBLE_FP8_KERNELS +from vllm.model_executor.kernels.linear import _POSSIBLE_FP8_BLOCK_KERNELS, _POSSIBLE_FP8_KERNELS from vllm.platforms import PlatformEnum +from vllm.model_executor.kernels.linear.scaled_mm.BlockScaledMMLinearKernel import Fp8BlockScaledMMLinearKernel from vllm.model_executor.kernels.linear.scaled_mm.pytorch import ( PerTensorTorchFP8ScaledMMLinearKernel, ChannelWiseTorchFP8ScaledMMLinearKernel, @@ -38,12 +39,36 @@ def is_supported(cls, compute_capability: int | None = None) -> tuple[bool, str return True, None +class HPUFp8BlockScaledMMLinearKernel(Fp8BlockScaledMMLinearKernel): + """HPU stub for block-scaled FP8 linear. + + The actual computation is handled by HPU-specific ops in + Fp8LinearMethod.apply(), so this kernel only needs to satisfy + the kernel selection interface. + """ + + @classmethod + def is_supported(cls, compute_capability: int | None = None) -> tuple[bool, str | None]: + return True, None + + def apply_weights(self, layer, x, bias=None): + raise NotImplementedError("HPU uses Fp8LinearMethod.apply() directly") + + def apply_block_scaled_mm(self, A, B, As, Bs): + raise NotImplementedError("HPU uses Fp8LinearMethod.apply() directly") + + if PlatformEnum.OOT not in _POSSIBLE_FP8_KERNELS: _POSSIBLE_FP8_KERNELS[PlatformEnum.OOT] = [ HPUPerTensorTorchFP8ScaledMMLinearKernel, HPUChannelWiseTorchFP8ScaledMMLinearKernel, ] +if PlatformEnum.OOT not in _POSSIBLE_FP8_BLOCK_KERNELS: + _POSSIBLE_FP8_BLOCK_KERNELS[PlatformEnum.OOT] = [ + HPUFp8BlockScaledMMLinearKernel, + ] + class Fp8LinearMethod(OrigFp8LinearMethod): From 0456f22156fcd4514c4466da0a6af14414d24895 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Olejniczak?= Date: Thu, 9 Apr 2026 17:50:27 +0300 Subject: [PATCH 4/5] Fix HPU MLA attention for removed accept_output_buffer attribute MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Paweł Olejniczak --- vllm_gaudi/attention/oot_mla.py | 46 ++++++++------------------------- 1 file changed, 11 insertions(+), 35 deletions(-) diff --git a/vllm_gaudi/attention/oot_mla.py b/vllm_gaudi/attention/oot_mla.py index 5f14c7a82..becc17713 100644 --- a/vllm_gaudi/attention/oot_mla.py +++ b/vllm_gaudi/attention/oot_mla.py @@ -69,19 +69,7 @@ def forward( # self.kv_cache_dtype, # self._k_scale, #) - if self.attn_backend.accept_output_buffer: - output = torch.empty(output_shape, dtype=q.dtype, device=q.device) - self.forward_impl( - q, - kv_c_normed, - k_pe, - self_kv_cache, - attn_metadata, - output=output, - ) - return output - else: - return self.forward_impl(q, kv_c_normed, k_pe, self_kv_cache, attn_metadata) + return self.forward_impl(q, kv_c_normed, k_pe, self_kv_cache, attn_metadata) else: kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update( kv_c_normed, @@ -90,25 +78,16 @@ def forward( self.kv_cache_dtype, self._k_scale, ) - if self.attn_backend.accept_output_buffer: - output = torch.empty(output_shape, dtype=q.dtype, device=q.device) - torch.ops.vllm.unified_mla_attention_with_output( - q, - kv_c_normed, - k_pe, - output, - self.layer_name, - kv_cache_dummy_dep=kv_cache_dummy_dep, - ) - return output - else: - return torch.ops.vllm.unified_mla_attention( - q, - kv_c_normed, - k_pe, - self.layer_name, - kv_cache_dummy_dep=kv_cache_dummy_dep, - ) + output = torch.empty(output_shape, dtype=q.dtype, device=q.device) + torch.ops.vllm.unified_mla_attention_with_output( + q, + kv_c_normed, + k_pe, + output, + self.layer_name, + kv_cache_dummy_dep=kv_cache_dummy_dep, + ) + return output def forward_impl( self, @@ -121,9 +100,6 @@ def forward_impl( output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: - if output is not None: - raise NotImplementedError("output is not yet supported for MLAImplBase") - is_prefill = attn_metadata.is_prompt if not is_prefill: From ae16ecb35097df10b9336f24ca61309d03f46dc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Olejniczak?= Date: Fri, 10 Apr 2026 15:16:11 +0300 Subject: [PATCH 5/5] Add manual_seed_all to HpuPlatform MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Paweł Olejniczak --- vllm_gaudi/platform.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm_gaudi/platform.py b/vllm_gaudi/platform.py index a3b0cf534..8599b021a 100644 --- a/vllm_gaudi/platform.py +++ b/vllm_gaudi/platform.py @@ -87,6 +87,10 @@ def set_device(cls, device: torch.device) -> None: """ return + @classmethod + def manual_seed_all(cls, seed: int) -> None: + torch.hpu.random.manual_seed_all(seed) + @classmethod def get_device_name(cls, device_id: int = 0) -> str: return cls.device_name