
[FIX_FOR_VLLM_CUSTOM=14acf429ac08b6d538ca6feb3e06b6d13895804d] Fix CPUOffloadingSpec import path and remove obsolete roberta patch#1229

Merged
iboiko-habana merged 5 commits into vllm-project:main from pawel-olejniczak:fix/vllm-hourly-cpu-offload-roberta
Mar 30, 2026
Conversation

@pawel-olejniczak pawel-olejniczak commented Mar 24, 2026

Summary

Multiple upstream vLLM changes broke the hourly CI (RED since 2026-03-23 13:07 UTC):

  1. CPUOffloadingSpec import path — upstream PR #37874 refactored cpu.py into a cpu/ package
  2. replace_roberta_positions removed — upstream PR #37884
  3. vllm_is_batch_invariant removed — replaced with envs.VLLM_BATCH_INVARIANT
  4. key_cache guard for None — upstream decode path can pass None
  5. Synapse SDPA error 400 — continuation prefills triggered Synapse errors
  6. Attention.kv_cache list-to-element refactor — upstream PR #37487 (c59a132f9) changed Attention.kv_cache from list to tensor. HPU code used self.kv_cache[0] producing garbage output.

Changes

  • cpu_hpu.py: Updated CPUOffloadingSpec import path
  • models/roberta.py: Removed obsolete monkey-patch
  • init.py: Removed roberta import
  • vllm_gaudi_batch_invariant.py: Replace vllm_is_batch_invariant with envs.VLLM_BATCH_INVARIANT
  • ops/hpu_paged_attn.py: Guard decode path against None key_cache
  • attention/hpu_attn.py: Fix SDPA padding for continuation prefills
  • ops/hpu_attention.py: self.kv_cache[0] -> self.kv_cache (fixes #6, "Fix CI fail hang")
  • attention/oot_mla.py: self.kv_cache[0] -> self.kv_cache (fixes #6, "Fix CI fail hang")

Impact

Fixes all 50+ e2e test failures and restores correct model output on HPU.


AI-assisted: All changes reviewed and verified on HPU hardware.

Copilot AI left a comment


Pull request overview

Fixes breakages caused by recent upstream vLLM refactors by updating import paths and removing an obsolete RoBERTa monkey-patch.

Changes:

  • Update CPUOffloadingSpec import to the new vllm.v1.kv_offload.cpu.spec module location.
  • Remove the now-obsolete RoBERTa forward monkey-patch and stop importing it during model registration.
  • Leave a short note in roberta.py explaining why the patch was removed.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
vllm_gaudi/v1/kv_offload/worker/cpu_hpu.py Switches to the new upstream import path for CPUOffloadingSpec.
vllm_gaudi/models/roberta.py Removes the previous monkey-patch implementation and replaces it with an explanatory note.
vllm_gaudi/init.py Stops importing the removed RoBERTa patch module during model registration.

Comment thread vllm_gaudi/v1/kv_offload/worker/cpu_hpu.py
Comment thread vllm_gaudi/models/roberta.py Outdated
@pawel-olejniczak pawel-olejniczak force-pushed the fix/vllm-hourly-cpu-offload-roberta branch from 248cbdd to 7377dd7 Compare March 24, 2026 12:03
@pawel-olejniczak pawel-olejniczak changed the title [FIX_FOR_VLLM_LATEST] Fix CPUOffloadingSpec import path and remove obsolete roberta patch [FIX_FOR_VLLM_CUSTOM=14acf429ac08b6d538ca6feb3e06b6d13895804d] Fix CPUOffloadingSpec import path and remove obsolete roberta patch Mar 24, 2026
if token_type_ids is not None:
assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
assert input_ids is not None
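For context, the assertion above relates to packing token-type ids into the token ids themselves: the check `vocab_size < (1 << TOKEN_TYPE_SHIFT)` only makes sense if token_type_ids are folded into the high bits of input_ids above the vocab range. A minimal sketch of that packing idea, with an illustrative shift value and hypothetical helper names (not the actual upstream implementation):

```python
TOKEN_TYPE_SHIFT = 30  # illustrative; requires vocab_size < (1 << TOKEN_TYPE_SHIFT)

def encode_token_type_ids(input_ids, token_type_ids):
    # Pack each token-type id into the high bits of the token id.
    return [tok | (tt << TOKEN_TYPE_SHIFT)
            for tok, tt in zip(input_ids, token_type_ids)]

def decode_token_type_ids(packed):
    # Recover both streams: low bits are the token id, high bits the type.
    mask = (1 << TOKEN_TYPE_SHIFT) - 1
    return [p & mask for p in packed], [p >> TOKEN_TYPE_SHIFT for p in packed]

ids, tts = decode_token_type_ids(encode_token_type_ids([5, 7], [0, 1]))
print(ids, tts)  # [5, 7] [0, 1]
```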

Collaborator

roberta.py was added in #1001 for special handling of _encode_token_type_ids. After roberta.py is removed, _encode_token_type_ids(input_ids, token_type_ids) from the upstream forward function will be used instead. Let's wait for the RoBERTa model test results.

Contributor Author

I have similar concerns here. Let’s wait for the test results.

@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

@pawel-olejniczak pawel-olejniczak force-pushed the fix/vllm-hourly-cpu-offload-roberta branch 7 times, most recently from 12464ac to 016dcc3 Compare March 30, 2026 07:33
…solete roberta patch

- Update CPUOffloadingSpec import from vllm.v1.kv_offload.cpu to
  vllm.v1.kv_offload.cpu.spec (upstream PR #37874 refactored cpu.py
  into a cpu/ package)
- Remove roberta monkey-patch that called the now-deleted
  replace_roberta_positions function (upstream PR #37884 moved the
  position offset adjustment into RobertaEmbedding.forward())
- Remove corresponding roberta import from register_models()

Signed-off-by: Pawel Olejniczak <pawelx.olejniczak@intel.com>
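One way to absorb a module-to-package refactor like this without hard-pinning to either vLLM version is a try/except import fallback. The sketch below simulates the idea with stand-in modules (the `fakevllm` package is invented so the example is self-contained; it is not the real vllm layout):

```python
import sys
import types

# Stand-in modules simulating the refactor of cpu.py into a cpu/ package.
class CPUOffloadingSpec:
    pass

spec_mod = types.ModuleType("fakevllm.kv_offload.cpu.spec")
spec_mod.CPUOffloadingSpec = CPUOffloadingSpec
for name in ("fakevllm", "fakevllm.kv_offload", "fakevllm.kv_offload.cpu"):
    sys.modules.setdefault(name, types.ModuleType(name))
sys.modules["fakevllm.kv_offload.cpu.spec"] = spec_mod

# Compatibility import: prefer the new package location, fall back to
# the old flat module when running against an older tree.
try:
    from fakevllm.kv_offload.cpu.spec import CPUOffloadingSpec as Spec
except ImportError:
    from fakevllm.kv_offload.cpu import CPUOffloadingSpec as Spec

print(Spec is CPUOffloadingSpec)  # True
```

The merged fix simply switches to the new path rather than keeping a fallback, which is the right call for a plugin that tracks a pinned upstream commit.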
…e removed vllm_is_batch_invariant with envs.VLLM_BATCH_INVARIANT

Upstream vLLM PR #35007 removed the vllm_is_batch_invariant() function
from batch_invariant.py, replacing it with a direct envs read.
Update vllm-gaudi to match.

Signed-off-by: Pawel Olejniczak <pawelx.olejniczak@intel.com>
Co-authored-by: GitHub Copilot
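The replacement pattern amounts to reading the flag directly instead of calling a helper. A minimal sketch, with illustrative parsing (vLLM's `envs` module applies its own conversion rules):

```python
import os

def batch_invariant_enabled() -> bool:
    # Direct env read replacing the removed vllm_is_batch_invariant() helper;
    # the accepted truthy values here are illustrative.
    return os.environ.get("VLLM_BATCH_INVARIANT", "0") in ("1", "true", "True")

os.environ["VLLM_BATCH_INVARIANT"] = "1"
print(batch_invariant_enabled())  # True
```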
…decode path against None key_cache

During V1 warmup with LoRA or KV-offloading, the decode path can be
called before KV caches are bound. flat_pa crashes with
AttributeError on key_cache.shape when key_cache is None.

Add a None check in the decode path of HPUAttentionImpl.forward to
return zeros when key_cache is not available, matching the defensive
pattern already used in the prompt path.

Signed-off-by: Pawel Olejniczak <pawelx.olejniczak@intel.com>
Co-authored-by: GitHub Copilot
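The defensive pattern described above can be sketched as follows. Names and shapes are illustrative (plain lists stand in for tensors), not the actual HPUAttentionImpl signature:

```python
def decode_forward(query, key_cache):
    # During V1 warmup (e.g. with LoRA or KV offloading) the decode path
    # can run before KV caches are bound, so key_cache may be None.
    if key_cache is None:
        # Return zeros of the output shape instead of crashing on
        # key_cache.shape, matching the prompt path's defensive pattern.
        return [[0.0] * len(q) for q in query]
    # ... real paged-attention decode would go here ...
    return query

query = [[1.0, 2.0], [3.0, 4.0]]
print(decode_forward(query, None))  # [[0.0, 0.0], [0.0, 0.0]]
```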
…_cache access after upstream list-to-element refactor

Upstream vLLM commit c59a132f9 (#37487) changed Attention.kv_cache from a
list of tensors to a single tensor. The HPU attention and MLA attention code
accessed self.kv_cache[0] which now returns the first sub-tensor slice
instead of the intended KV cache tensor, causing corrupted inference results.

Fix: Replace self.kv_cache[0] with self.kv_cache in both affected files.

Signed-off-by: Pawel Olejniczak <pawelx.olejniczak@intel.com>
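The failure mode is easy to reproduce in miniature. The sketch below uses nested lists as stand-ins for tensors to show why `self.kv_cache[0]` silently returned a sub-slice after the list-to-element refactor:

```python
kv_tensor = [[1.0, 2.0], [3.0, 4.0]]  # stand-in for the bound KV cache tensor

# Before upstream #37487: kv_cache was a list wrapping the tensor,
# so [0] unwrapped it correctly.
old_style = [kv_tensor]
assert old_style[0] is kv_tensor

# After the refactor: kv_cache IS the tensor, so [0] now slices off the
# first sub-tensor instead, silently corrupting attention results.
new_style = kv_tensor
assert new_style[0] == [1.0, 2.0]  # a sub-slice, not the cache

# The fix: use the attribute directly.
assert new_style is kv_tensor
```

No exception is raised in the buggy case, which is why the symptom was garbage output rather than a crash.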
@pawel-olejniczak pawel-olejniczak force-pushed the fix/vllm-hourly-cpu-offload-roberta branch from 016dcc3 to ea814d0 Compare March 30, 2026 12:14
…_cache indexing in Qwen3.5 GatedDeltaNet

self.kv_cache is already a tuple (conv_state, ssm_state) assigned
by the HPU model runner. The redundant intermediate index
self.kv_cache[0][0/1] collapsed conv_state from 3-D to 2-D,
causing an IndexError during Dynamo tracing.

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
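The over-indexing bug can be sketched the same way, again with nested lists standing in for the state tensors (shapes are illustrative):

```python
conv_state = [[[0.0] * 4] * 2] * 3   # 3-D stand-in
ssm_state = [[0.0] * 4] * 2          # 2-D stand-in
kv_cache = (conv_state, ssm_state)   # already the (conv, ssm) tuple

# Buggy: kv_cache[0][0] indexes INTO conv_state, collapsing 3-D to 2-D,
# which later surfaced as an IndexError during Dynamo tracing.
buggy_conv = kv_cache[0][0]
assert len(buggy_conv) == 2  # lost the leading dimension

# Fixed: unpack the tuple directly, no redundant intermediate index.
conv, ssm = kv_cache
assert conv is conv_state and ssm is ssm_state
```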
@pawel-olejniczak pawel-olejniczak force-pushed the fix/vllm-hourly-cpu-offload-roberta branch from ea814d0 to 963be20 Compare March 30, 2026 12:16
@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
14acf429ac08b6d538ca6feb3e06b6d13895804d

@iboiko-habana iboiko-habana merged commit 0fffded into vllm-project:main Mar 30, 2026
78 of 90 checks passed
adobrzyn added a commit that referenced this pull request Apr 1, 2026
…] Fix CPUOffloadingSpec import path and remove obsolete roberta patch (#1229)"

This reverts commit 0fffded.