@@ -8,8 +8,7 @@
generate_store_output, )
from tests.unit_tests.kv_offload.utils import EOS_TOKEN_ID
from vllm.distributed.kv_events import BlockRemoved, BlockStored
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import OffloadingEvent
from vllm.v1.kv_offload.abstract import OffloadingEvent, OffloadKey, make_offload_key
from vllm.v1.request import RequestStatus


@@ -108,19 +107,19 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner.run(decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5))

# test take_events
def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
return [BlockHash(str(i).encode()) for i in int_hashes]
def to_keys(int_ids: list[int]) -> list[OffloadKey]:
return [make_offload_key(str(i).encode(), 0) for i in int_ids]

def take_events() -> Iterable[OffloadingEvent]:
yield OffloadingEvent(block_hashes=to_hashes([1, 2, 3]), block_size=16, medium="A", removed=False)
yield OffloadingEvent(block_hashes=to_hashes([4, 5, 6]), block_size=32, medium="B", removed=True)
yield OffloadingEvent(keys=to_keys([1, 2, 3]), block_size=16, medium="A", removed=False)
yield OffloadingEvent(keys=to_keys([4, 5, 6]), block_size=32, medium="B", removed=True)

runner.manager.take_events.side_effect = take_events
events = list(runner.scheduler_connector.take_events())
assert len(events) == 2
event = events[0]
assert isinstance(event, BlockStored)
assert event.block_hashes == to_hashes([1, 2, 3])
assert event.block_hashes == [str(i).encode() for i in [1, 2, 3]]
assert event.block_size == 16
assert event.medium == "A"
assert event.token_ids == []
@@ -129,7 +128,7 @@ def take_events() -> Iterable[OffloadingEvent]:
assert event.lora_name is None
event = events[1]
assert isinstance(event, BlockRemoved)
assert event.block_hashes == to_hashes([4, 5, 6])
assert event.block_hashes == [str(i).encode() for i in [4, 5, 6]]
assert event.medium == "B"


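The test above only relies on make_offload_key accepting hash bytes plus an integer, and on the resulting events exposing those bytes as block_hashes. A minimal sketch of that assumed shape (the fields and constructor below are an illustration, not the actual vllm.v1.kv_offload.abstract definitions):

# Sketch only: assumed shape of OffloadKey / make_offload_key, inferred from how
# the test constructs keys; the real definitions live in vllm.v1.kv_offload.abstract.
from typing import NamedTuple

class OffloadKey(NamedTuple):
    block_hash: bytes  # e.g. str(i).encode() in the test above
    group_id: int      # KV group / medium index the block belongs to

def make_offload_key(block_hash: bytes, group_id: int) -> OffloadKey:
    return OffloadKey(block_hash, group_id)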
10 changes: 6 additions & 4 deletions tests/unit_tests/kv_offload/offloading_connector/utils.py
@@ -216,8 +216,10 @@ def __init__(self,
assert isinstance(manager, MagicMock)
self.manager: MagicMock = manager

assert connector_scheduler.gpu_block_size == gpu_block_size
assert connector_scheduler.offloaded_block_size == offloaded_block_size
assert len(connector_scheduler.config.kv_group_configs) == 1
kv_group_config = connector_scheduler.config.kv_group_configs[0]
assert kv_group_config.gpu_block_size == gpu_block_size
assert kv_group_config.offloaded_block_size == offloaded_block_size

# extract OffloadingSpec of worker_connector
connector_worker = self.worker_connector.connector_worker
@@ -459,7 +461,7 @@ def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks, async_s
def generate_store_output(block_hashes: Iterable[BlockHash]):
block_hashes = list(block_hashes)
return PrepareStoreOutput(
block_hashes_to_store=list(block_hashes),
keys_to_store=list(block_hashes),
store_spec=MockLoadStoreSpec(block_hashes),
block_hashes_evicted=[],
evicted_keys=[],
)
Comment on lines 461 to 467 (Copilot AI, Apr 10, 2026):

generate_store_output and MockLoadStoreSpec still use block_hashes naming, but the returned PrepareStoreOutput now uses the renamed keys_to_store / evicted_keys fields. Consider renaming the function parameter/local variables (and related mock spec fields, if appropriate) to keys to match the updated API and avoid confusion when reading or extending these tests.
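If the suggested rename were applied, the helper might read roughly as follows (a sketch only; the Iterable[OffloadKey] annotation and the MockLoadStoreSpec argument rename are assumptions beyond what the PR itself changes):

# Sketch of the reviewer's proposed rename; PrepareStoreOutput, MockLoadStoreSpec
# and OffloadKey are assumed to be imported as in the existing module.
from collections.abc import Iterable

def generate_store_output(keys: Iterable[OffloadKey]):
    keys = list(keys)
    return PrepareStoreOutput(
        keys_to_store=list(keys),
        store_spec=MockLoadStoreSpec(keys),
        evicted_keys=[],
    )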
24 changes: 24 additions & 0 deletions tests/unit_tests/ops/conftest.py
@@ -0,0 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import pytest
import torch

from vllm.config import VllmConfig, set_current_vllm_config


@pytest.fixture
def default_vllm_config():
"""VllmConfig with a minimal model_config stub.

Upstream Fp8LinearMethod.__init__ accesses model_config.dtype.
We provide a SimpleNamespace with the attributes required by quantization
methods so that ops-level unit tests can run without a full model setup.
"""
vllm_config = VllmConfig()
vllm_config.model_config = SimpleNamespace(dtype=torch.bfloat16, is_moe=False, hf_config=None, quantization=None)

with set_current_vllm_config(vllm_config):
yield
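As an illustration of the intended use, a hypothetical ops-level test would simply request the fixture so the stubbed config is active for its body (the test name and the get_current_vllm_config accessor below are assumptions, not part of the PR):

# Hypothetical usage sketch: requesting the fixture makes the stubbed
# VllmConfig the current config for the duration of the test.
import torch
from vllm.config import get_current_vllm_config

def test_ops_see_stubbed_model_config(default_vllm_config):
    config = get_current_vllm_config()
    assert config.model_config.dtype == torch.bfloat16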
46 changes: 11 additions & 35 deletions vllm_gaudi/attention/oot_mla.py
@@ -69,19 +69,7 @@ def forward(
# self.kv_cache_dtype,
# self._k_scale,
#)
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.forward_impl(
q,
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata,
output=output,
)
return output
else:
return self.forward_impl(q, kv_c_normed, k_pe, self_kv_cache, attn_metadata)
return self.forward_impl(q, kv_c_normed, k_pe, self_kv_cache, attn_metadata)
else:
kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update(
kv_c_normed,
@@ -90,25 +78,16 @@
self.kv_cache_dtype,
self._k_scale,
)
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
q,
kv_c_normed,
k_pe,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output
else:
return torch.ops.vllm.unified_mla_attention(
q,
kv_c_normed,
k_pe,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
q,
kv_c_normed,
k_pe,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output

def forward_impl(
self,
@@ -121,9 +100,6 @@ def forward_impl(
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
Copilot AI (Apr 10, 2026):

forward_impl still accepts output / output_scale / output_block_scale, but the implementation ignores these parameters and will overwrite output locally. Since the earlier explicit NotImplementedError guard was removed, callers that pass an output buffer could now get silently incorrect behavior. Consider either (a) restoring an explicit error when output is provided, or (b) implementing true output-buffer support (writing into the provided tensor) and documenting the contract.

Suggested change:
) -> torch.Tensor:
    if (output is not None or output_scale is not None
            or output_block_scale is not None):
        raise NotImplementedError(
            "HPUMLAAttention.forward_impl does not support caller-"
            "provided output, output_scale, or output_block_scale.")
if output is not None:
raise NotImplementedError("output is not yet supported for MLAImplBase")

is_prefill = attn_metadata.is_prompt

if not is_prefill:
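Option (b) from the comment above could look roughly like the following inside forward_impl (a sketch only, assuming the computed attention output is held in a local tensor named result):

# Sketch of caller-provided output-buffer support; `result` is assumed to be the
# tensor forward_impl currently computes and returns.
if output is not None:
    output.copy_(result)  # write into the caller's buffer instead of allocating
    return output
return result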
2 changes: 1 addition & 1 deletion vllm_gaudi/ops/hpu_attention.py
@@ -28,7 +28,7 @@ def patched_attention_forward(
context using
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if self.use_output or not self.use_direct_call:
if not self.use_direct_call:
return layer.Attention._vllm_gaudi_original_forward(self, query, key, value, output_shape=output_shape)

if self.calculate_kv_scales:
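For reference, the direct-call path reads its metadata from the ambient forward context along the lines the docstring describes (a sketch of that pattern, not the exact function body):

# Pattern referenced by the docstring above (sketch only).
from vllm.forward_context import get_forward_context

attn_metadata = get_forward_context().attn_metadata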
27 changes: 26 additions & 1 deletion vllm_gaudi/ops/hpu_fp8.py
@@ -16,8 +16,9 @@
from vllm_gaudi.ops.hpu_fused_moe import _normalize_moe_activation
from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_hidden_states, dispatch_tensor, get_hpu_dp_metadata

from vllm.model_executor.kernels.linear import _POSSIBLE_FP8_KERNELS
from vllm.model_executor.kernels.linear import _POSSIBLE_FP8_BLOCK_KERNELS, _POSSIBLE_FP8_KERNELS
from vllm.platforms import PlatformEnum
from vllm.model_executor.kernels.linear.scaled_mm.BlockScaledMMLinearKernel import Fp8BlockScaledMMLinearKernel
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
PerTensorTorchFP8ScaledMMLinearKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
@@ -38,12 +39,36 @@ def is_supported(cls, compute_capability: int | None = None) -> tuple[bool, str
return True, None


class HPUFp8BlockScaledMMLinearKernel(Fp8BlockScaledMMLinearKernel):
"""HPU stub for block-scaled FP8 linear.

The actual computation is handled by HPU-specific ops in
Fp8LinearMethod.apply(), so this kernel only needs to satisfy
the kernel selection interface.
"""

@classmethod
def is_supported(cls, compute_capability: int | None = None) -> tuple[bool, str | None]:
return True, None

def apply_weights(self, layer, x, bias=None):
raise NotImplementedError("HPU uses Fp8LinearMethod.apply() directly")

def apply_block_scaled_mm(self, A, B, As, Bs):
raise NotImplementedError("HPU uses Fp8LinearMethod.apply() directly")


if PlatformEnum.OOT not in _POSSIBLE_FP8_KERNELS:
_POSSIBLE_FP8_KERNELS[PlatformEnum.OOT] = [
HPUPerTensorTorchFP8ScaledMMLinearKernel,
HPUChannelWiseTorchFP8ScaledMMLinearKernel,
]

if PlatformEnum.OOT not in _POSSIBLE_FP8_BLOCK_KERNELS:
_POSSIBLE_FP8_BLOCK_KERNELS[PlatformEnum.OOT] = [
HPUFp8BlockScaledMMLinearKernel,
]
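Because the stub only needs to satisfy kernel selection, a quick hypothetical sanity check (not part of the PR) would be:

# Hypothetical sanity check: the stub must report itself as supported so that
# kernel selection keeps it as a candidate on the OOT (HPU) platform.
ok, reason = HPUFp8BlockScaledMMLinearKernel.is_supported()
assert ok is True and reason is None

The guarded `if PlatformEnum.OOT not in ...` registration keeps repeated imports of this module from appending duplicate kernel entries.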


class Fp8LinearMethod(OrigFp8LinearMethod):

4 changes: 4 additions & 0 deletions vllm_gaudi/platform.py
@@ -87,6 +87,10 @@ def set_device(cls, device: torch.device) -> None:
"""
return

@classmethod
def manual_seed_all(cls, seed: int) -> None:
torch.hpu.random.manual_seed_all(seed)

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return cls.device_name
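A typical call site would go through the platform object so device-agnostic code can seed the HPU RNG (hypothetical usage; assumes the OOT HPU platform is the active current_platform):

# Hypothetical usage sketch: seed all HPU devices through the platform hook.
from vllm.platforms import current_platform

current_platform.manual_seed_all(1234)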