
[FIX_FOR_VLLM_CUSTOM=f976e3b98ba45677a2213673a442c6cbff141e8e] Fix upstream regressions in attention, FP8, offloading and platform#1338

Merged
iboiko-habana merged 5 commits into vllm-project:main from pawel-olejniczak:fix/vllm-hourly-10-4
Apr 13, 2026

Conversation

@pawel-olejniczak
Contributor

Summary

Fixes five regressions introduced by recent upstream vLLM changes that break HPU unit tests and model execution.

Changes

  1. Remove the `use_output` guard from the HPU attention patch (the attribute was removed upstream).
  2. Remove `accept_output_buffer` branching from HPU MLA attention: the attribute was removed upstream, so the opaque path now unconditionally uses the output buffer, while the direct call path manages output internally.
  3. Update the KV offloading connector tests for upstream field renames: `block_hashes` → `keys`, `block_hashes_to_store` → `keys_to_store`, plus config access via `kv_group_configs[0]`.
  4. Register the HPU FP8 block-scaled kernel and add an ops test conftest: the new `_POSSIBLE_FP8_BLOCK_KERNELS` dict needs an OOT entry, and a `VllmConfig` stub is provided for ops unit tests.
  5. Add `manual_seed_all` to `HpuPlatform`, a newly required platform method for RNG seeding.
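Item 5 can be pictured with a minimal sketch. Only the method name `manual_seed_all` comes from this PR; the upstream `Platform` interface and the real HPU seeding call are not reproduced here, so the class below is a hypothetical stand-in that seeds Python's RNG to stay self-contained:

```python
# Hedged sketch: illustrates the shape of the new platform hook.
# The real HpuPlatform would seed the HPU device RNGs (analogous to
# torch.cuda.manual_seed_all); this stand-in only seeds Python's RNG.
import random


class HpuPlatformSketch:
    """Hypothetical stand-in for vllm_gaudi.platform.HpuPlatform."""

    @classmethod
    def manual_seed_all(cls, seed: int) -> None:
        # Seed every RNG the platform owns; here, just Python's.
        random.seed(seed)


# Same seed, same first draw: the property the hook exists to provide.
HpuPlatformSketch.manual_seed_all(42)
first = random.random()
HpuPlatformSketch.manual_seed_all(42)
assert random.random() == first
```

The design point is that upstream vLLM now calls one platform-level entry point for RNG seeding instead of hard-coding a CUDA call, so out-of-tree platforms must expose it.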

Upstream PRs that introduced these regressions

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>

Copilot AI left a comment


Pull request overview

Fixes multiple breakages caused by upstream vLLM API changes, keeping Gaudi/HPU attention, FP8, KV offloading tests, and platform integration compatible with the new interfaces.

Changes:

  • Update HPU attention (regular + MLA) to align with upstream removal of use_output / accept_output_buffer.
  • Register an HPU FP8 block-scaled kernel stub and add an ops test conftest providing a minimal VllmConfig context.
  • Update KV offloading connector unit tests for upstream event/output field renames and config layout changes.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| vllm_gaudi/platform.py | Adds `manual_seed_all` required by the upstream Platform API. |
| vllm_gaudi/ops/hpu_fp8.py | Registers an OOT entry for `_POSSIBLE_FP8_BLOCK_KERNELS` via an HPU stub kernel class. |
| vllm_gaudi/ops/hpu_attention.py | Removes dependency on the upstream-removed `use_output` attribute in attention patching logic. |
| vllm_gaudi/attention/oot_mla.py | Removes `accept_output_buffer` branching and standardizes output-buffer usage in the opaque path. |
| tests/unit_tests/ops/conftest.py | Introduces a fixture that sets a minimal current `VllmConfig` for ops unit tests. |
| tests/unit_tests/kv_offload/offloading_connector/utils.py | Updates scheduler config assertions and adapts the mocked `PrepareStoreOutput` to renamed fields. |
| tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py | Updates offloading event tests to use `OffloadKey`/`make_offload_key` and new event fields. |

Comment on lines 461 to 467
```diff
 def generate_store_output(block_hashes: Iterable[BlockHash]):
     block_hashes = list(block_hashes)
     return PrepareStoreOutput(
-        block_hashes_to_store=list(block_hashes),
+        keys_to_store=list(block_hashes),
         store_spec=MockLoadStoreSpec(block_hashes),
-        block_hashes_evicted=[],
+        evicted_keys=[],
     )
```

Copilot AI Apr 10, 2026


`generate_store_output` and `MockLoadStoreSpec` still use `block_hashes` naming, but the returned `PrepareStoreOutput` now uses the renamed `keys_to_store` / `evicted_keys` fields. Consider renaming the function parameter and local variables (and related mock spec fields, if appropriate) to `keys` to match the updated API and avoid confusion when reading or extending these tests.
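If that suggestion were taken, the helper might read as below. This is a hedged sketch: `PrepareStoreOutput` and `MockLoadStoreSpec` are simplified local stand-ins rather than the real vLLM types, with only the field names `keys_to_store` / `evicted_keys` taken from the rename this comment describes:

```python
# Hedged sketch of the rename suggested above, with local stand-in
# types so the example is self-contained and runnable.
from dataclasses import dataclass, field
from typing import Iterable


@dataclass
class MockLoadStoreSpec:
    """Stand-in for the test's mock load/store spec."""
    keys: list


@dataclass
class PrepareStoreOutput:
    """Stand-in carrying the renamed upstream fields."""
    keys_to_store: list
    store_spec: MockLoadStoreSpec
    evicted_keys: list = field(default_factory=list)


def generate_store_output(keys: Iterable) -> PrepareStoreOutput:
    # Parameter renamed from `block_hashes` to `keys` to match the
    # renamed output fields.
    keys = list(keys)
    return PrepareStoreOutput(
        keys_to_store=list(keys),
        store_spec=MockLoadStoreSpec(keys),
        evicted_keys=[],
    )


out = generate_store_output(["h1", "h2"])
assert out.keys_to_store == ["h1", "h2"]
```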

```diff
@@ -121,9 +100,6 @@ def forward_impl(
         output_scale: torch.Tensor | None = None,
         output_block_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
```

Copilot AI Apr 10, 2026


forward_impl still accepts output / output_scale / output_block_scale, but the implementation ignores these parameters and will overwrite output locally. Since the earlier explicit NotImplementedError guard was removed, callers that pass an output buffer could now get silently incorrect behavior. Consider either (a) restoring an explicit error when output is provided, or (b) implementing true output-buffer support (writing into the provided tensor) and documenting the contract.

Suggested change:

```diff
     ) -> torch.Tensor:
+        if (output is not None or output_scale is not None
+                or output_block_scale is not None):
+            raise NotImplementedError(
+                "HPUMLAAttention.forward_impl does not support caller-"
+                "provided output, output_scale, or output_block_scale.")
```

@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
f976e3b98ba45677a2213673a442c6cbff141e8e

iboiko-habana merged commit c7b510f into vllm-project:main on Apr 13, 2026
74 of 75 checks passed
3 participants