
[FIX_FOR_VLLM_CUSTOM=5b39b268f506150dbab38f6f6c04b7c843e37c07] Fix upstream regressions: MoE refactor, DeepSeek V4 router, KV offload HMA#1403

Merged
iboiko-habana merged 10 commits into vllm-project:main from pawel-olejniczak:fix/vllm-hourly-29-4 on May 6, 2026
Conversation

@pawel-olejniczak pawel-olejniczak commented Apr 29, 2026

Fixes multiple upstream vLLM regressions that were breaking vllm-gaudi unit tests at vLLM commit 5b39b268f506150dbab38f6f6c04b7c843e37c07.

Fixes

1. MoE runner import rename (upstream #40560)

moe_runner_base.py was removed and split into moe_runner.py and moe_runner_interface.py, and the MoERunnerBase class was renamed to MoERunner.

Changes:

  • vllm_gaudi/ops/hpu_fused_moe.py: Import MoERunner as MoERunnerBase from moe_runner module; update get_layer_from_name import path
  • vllm_gaudi/ops/hpu_lora.py: Change _all_lora_classes from set to tuple (upstream #35077)
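The import-alias approach can be sketched as follows. This is a minimal illustration of the pattern (aliasing the renamed upstream class under the old local name), not the actual vllm_gaudi code; the stand-in module below exists only so the snippet runs without vLLM installed.

```python
# Sketch of the import-alias pattern used to absorb the upstream rename.
# Upstream renamed MoERunnerBase to MoERunner and moved it into moe_runner.py;
# downstream code keeps its old local name by aliasing at import time.

import sys
import types

# Stand-in for the upstream moe_runner module (not the real vLLM path).
moe_runner = types.ModuleType("moe_runner")

class MoERunner:  # the class after the upstream rename
    def forward(self, x):
        return x

moe_runner.MoERunner = MoERunner
sys.modules["moe_runner"] = moe_runner

# Downstream alias: import the new name under the old one so existing
# subclasses and isinstance() checks keep working unchanged.
from moe_runner import MoERunner as MoERunnerBase

class HPUMoERunner(MoERunnerBase):
    pass
```

Since the alias binds the very same class object, no subclass or registry code below the import needs to change.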

2. DeepSeek V4 router API — hash_indices_table (upstream #40860)

FusedMoE.__init__ now passes hash_indices_table and zero_expert_type to create_fused_moe_router(), and input_ids kwarg to apply_monolithic().

Changes:

  • vllm_gaudi/ops/hpu_fused_moe.py: Add hash_indices_table parameter to HPU create_fused_moe_router() override; pass it to FusedTopKBiasRouter; add scoring_func assertion
  • vllm_gaudi/ops/hpu_compressed_tensors.py: Add **kwargs to HPUCompressedTensorsWNA16MoEMethod.apply_monolithic() for input_ids kwarg
  • tests/unit_tests/ops/utils.py: Add zero_expert_type and hash_indices_table to create_fused_moe() test helper
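The router-selection change can be sketched as below. The factory and router names follow the PR description, but the classes here are simplified stubs, not the real vLLM implementations; the selection logic mirrors the behavior described above (bias OR hash table selects the bias router, and "sqrtsoftplus" is newly allowed).

```python
# Minimal sketch of the updated router factory (stub classes, illustrative only).
from dataclasses import dataclass
from typing import Optional

@dataclass
class FusedTopKBiasRouter:
    e_score_correction_bias: Optional[list] = None
    hash_indices_table: Optional[list] = None

@dataclass
class DefaultRouter:
    pass

def create_fused_moe_router(scoring_func: str,
                            e_score_correction_bias=None,
                            hash_indices_table=None):
    # New assertion: "sqrtsoftplus" is now an accepted scoring function.
    assert scoring_func in ["sigmoid", "softmax", "sqrtsoftplus"]
    # FusedTopKBiasRouter is selected when EITHER a correction bias OR a
    # hash indices table is provided.
    if e_score_correction_bias is not None or hash_indices_table is not None:
        return FusedTopKBiasRouter(e_score_correction_bias, hash_indices_table)
    return DefaultRouter()
```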

3. MoE test _forward_dispatch removal (upstream #40560)

MoERunnerBase._forward_dispatch() was removed. Tests must use runner.forward() with a proper ForwardContext.

Changes:

  • tests/unit_tests/ops/test_hpu_fused_moe.py: Replace _forward_dispatch call with runner.forward(); use real ForwardContext with no_compile_layers
  • tests/unit_tests/ops/test_hpu_compressed_tensors.py: Same migration
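The shape of that test migration looks roughly like this. ForwardContext and no_compile_layers are names from the PR description; everything else here is a simplified stand-in to show the pattern of replacing a private-method call with the public forward() under an active context.

```python
# Sketch: tests build a ForwardContext and call the public forward()
# instead of the removed private _forward_dispatch(). Stand-in classes only.
from contextlib import contextmanager
from dataclasses import dataclass, field

@dataclass
class ForwardContext:
    no_compile_layers: dict = field(default_factory=dict)

_current_ctx = None

@contextmanager
def set_forward_context(ctx):
    # Install the context for the duration of the forward call.
    global _current_ctx
    _current_ctx = ctx
    try:
        yield
    finally:
        _current_ctx = None

class Runner:
    def forward(self, x):
        # The public entry point requires an active forward context.
        assert _current_ctx is not None, "forward() needs an active ForwardContext"
        return [v * 2 for v in x]

runner = Runner()
ctx = ForwardContext(no_compile_layers={"moe_layer": runner})
with set_forward_context(ctx):
    out = runner.forward([1, 2, 3])
```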

4. KV offload scheduler — HMA multi-group + per-job store completion (upstream #39186, #39403, #38453, #39401, #39402)

Offloading scheduler was refactored for multi-group KV support and per-job store completion tracking.

Changes:

  • tests/unit_tests/kv_offload/offloading_connector/utils.py: Sync with upstream — OffloadKey/OffloadingWorkerMetadata types, async scheduling support, TransferJobStatus tracking, build_connector_worker_meta() integration
  • tests/unit_tests/kv_offload/utils.py: Add kv_connector_worker_meta parameter to create_model_runner_output()
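Per-job store completion tracking can be illustrated as below. TransferJobStatus is a name from the PR description; the enum values and bookkeeping here are illustrative stand-ins for the idea that completion is now tracked per transfer job rather than per request.

```python
# Sketch of per-job store completion tracking (illustrative stand-in).
from enum import Enum, auto

class TransferJobStatus(Enum):
    PENDING = auto()
    DONE = auto()

class StoreTracker:
    def __init__(self):
        self.jobs = {}  # job_id -> TransferJobStatus

    def submit(self, job_id):
        self.jobs[job_id] = TransferJobStatus.PENDING

    def complete(self, job_id):
        self.jobs[job_id] = TransferJobStatus.DONE

    def all_done(self, job_ids):
        # Completion is checked per job, not per request.
        return all(self.jobs.get(j) is TransferJobStatus.DONE for j in job_ids)

tracker = StoreTracker()
for j in (0, 1):
    tracker.submit(j)
tracker.complete(0)
partial = tracker.all_done([0, 1])   # one job still pending
tracker.complete(1)
finished = tracker.all_done([0, 1])  # both jobs stored
```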

5. OffloadingConnectorMetadata per-job API in model runner (upstream #39186)

OffloadingConnectorMetadata fields were renamed from reqs_to_store/reqs_to_load (dict[ReqId, TransferSpec]) to store_jobs/load_jobs (dict[int, TransferJob]), with req_id now inside TransferJob. This caused EngineDeadError crashes in all tests that use KV offloading or LoRA (3 CI failures).

Changes:

  • vllm_gaudi/v1/worker/hpu_model_runner.py: Update _get_prompts_and_decodes() to extract req_id from TransferJob objects in store_jobs/load_jobs instead of iterating over removed reqs_to_store/reqs_to_load — both for direct OffloadingConnectorMetadata and nested MultiKVConnectorMetadata cases
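The metadata migration can be sketched like this. Field names (store_jobs, load_jobs, req_id) follow the PR description; the dataclasses are simplified stand-ins that omit the real transfer-spec fields.

```python
# Sketch: metadata is now keyed by job id, with req_id inside each TransferJob.
from dataclasses import dataclass, field

@dataclass
class TransferJob:
    req_id: str
    # real transfer-spec fields omitted

@dataclass
class OffloadingConnectorMetadata:
    store_jobs: dict = field(default_factory=dict)  # dict[int, TransferJob]
    load_jobs: dict = field(default_factory=dict)   # dict[int, TransferJob]

def req_ids_with_kv_transfers(meta: OffloadingConnectorMetadata) -> set:
    # Old code iterated the reqs_to_store / reqs_to_load dict keys directly;
    # now the req_id must be read from each TransferJob value.
    return ({job.req_id for job in meta.store_jobs.values()}
            | {job.req_id for job in meta.load_jobs.values()})

meta = OffloadingConnectorMetadata(
    store_jobs={0: TransferJob("req-a"), 1: TransferJob("req-b")},
    load_jobs={2: TransferJob("req-a")},
)
```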

6. LoRA punica wrapper — add_shrink / add_expand (upstream #35077)

Upstream refactored PunicaWrapperBase to add add_shrink() and add_expand() methods. HPU punica wrapper was missing these, causing AttributeError.

Changes:

  • vllm_gaudi/lora/hpu_punica_wrapper.py: Implement add_shrink() and add_expand() methods in HPU punica wrapper
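The two primitives can be sketched as follows. In upstream punica wrappers, add_shrink projects the input through the low-rank A matrix and add_expand projects back through B and accumulates into the output; the pure-Python matmul below is a stand-in for the HPU kernels, and the exact upstream signatures may differ.

```python
# Sketch of the two LoRA primitives (pure-Python stand-in, not the HPU kernel).
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

class HPUPunicaWrapperSketch:
    def add_shrink(self, buffer, x, lora_a):
        # buffer <- x @ lora_a  (down-project hidden states to LoRA rank)
        result = matmul(x, lora_a)
        for i, row in enumerate(result):
            buffer[i] = row

    def add_expand(self, y, buffer, lora_b, scale=1.0):
        # y += (buffer @ lora_b) * scale  (up-project and accumulate)
        result = matmul(buffer, lora_b)
        for i, row in enumerate(result):
            for j, v in enumerate(row):
                y[i][j] += v * scale

w = HPUPunicaWrapperSketch()
x = [[1.0, 2.0]]          # one token, hidden size 2
lora_a = [[1.0], [1.0]]   # hidden 2 -> rank 1
lora_b = [[2.0, 0.0]]     # rank 1 -> hidden 2
buf = [None]
w.add_shrink(buf, x, lora_a)
y = [[0.0, 0.0]]
w.add_expand(y, buf, lora_b, scale=0.5)
```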

7. Rejection sampler — synthetic_mode kwarg (upstream #40662)

Upstream added synthetic_mode parameter to rejection_sample(). HPU override was missing it.

Changes:

  • vllm_gaudi/v1/worker/hpu_model_runner.py: Accept synthetic_mode kwarg in HPU rejection_sample override
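The fix is simply signature compatibility, sketched below. The synthetic_mode name comes from the PR; the function body is a trivial stand-in, not the real sampler logic.

```python
# Sketch: the HPU override must accept the new upstream keyword so call
# sites that pass it don't raise TypeError. Stand-in logic only.
def hpu_rejection_sample(draft_token_ids, target_probs, synthetic_mode=None):
    # synthetic_mode is accepted for signature compatibility; the HPU path
    # may ignore it or branch on it as needed.
    if synthetic_mode:
        return draft_token_ids  # hypothetical pass-through behavior
    # trivial stand-in for the real accept/reject logic
    return [t for t, p in zip(draft_token_ids, target_probs) if p > 0.5]

# Upstream now calls with the new kwarg; the override no longer raises.
accepted = hpu_rejection_sample([7, 8, 9], [0.9, 0.2, 0.8], synthetic_mode=False)
```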

8. MoE forward — pass input_ids to custom op (upstream #40860)

DeepSeek V4 support added input_ids parameter at position 3 of _moe_forward / _moe_forward_shared custom ops. patched_fused_moe_forward was not passing it through to _forward_impl and _forward_entry, causing RuntimeError: Expected Optional[Tensor] but found str in hpu_dp_tests.

Changes:

  • vllm_gaudi/ops/hpu_fused_moe.py: Pass input_ids through both _forward_impl and _forward_entry calls in patched_fused_moe_forward()
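The failure mode and fix can be sketched like this: when a new positional parameter is inserted mid-signature and a wrapper does not thread it through, a later argument (here the layer-name string) slides into the tensor slot. Function names follow the PR; bodies are simplified stand-ins.

```python
# Sketch: the custom op gained input_ids at position 3, so the patched
# forward must pass it through explicitly. Stand-in for the real ops.
def _moe_forward(hidden_states, router_logits, input_ids, layer_name):
    # Upstream expects Optional[Tensor] input_ids in position 3; receiving
    # the layer-name string here is what produced
    # "RuntimeError: Expected Optional[Tensor] but found str".
    assert not isinstance(input_ids, str), "input_ids slot got a string"
    return hidden_states

def patched_fused_moe_forward(hidden_states, router_logits, input_ids,
                              layer_name):
    # Fixed version: input_ids is threaded through in its proper slot.
    return _moe_forward(hidden_states, router_logits, input_ids, layer_name)

out = patched_fused_moe_forward([0.1, 0.2], [0.5, 0.5], None, "moe.layer.0")
```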

Copilot AI (Contributor) left a comment


Pull request overview

Fixes multiple vLLM upstream API/regression changes (MoE runner refactor, DeepSeek V4 router args, KV offload scheduler refactor) to restore vllm-gaudi unit test compatibility at vLLM commit 5b39b268f506150dbab38f6f6c04b7c843e37c07.

Changes:

  • Update HPU MoE integration for upstream runner/router API changes (including hash_indices_table and router selection behavior).
  • Migrate MoE-related unit tests from removed _forward_dispatch to runner.forward() with a real ForwardContext.
  • Sync KV offload test utilities with upstream multi-group/async scheduling + per-job completion tracking changes.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
vllm_gaudi/ops/hpu_lora.py Adapts LoRA registration/patching for upstream _all_lora_classes container change.
vllm_gaudi/ops/hpu_fused_moe.py Updates MoE runner imports and extends router factory to accept hash_indices_table for DeepSeek V4.
vllm_gaudi/ops/hpu_compressed_tensors.py Loosens apply_monolithic signature to accept new upstream kwargs (e.g., input_ids).
tests/unit_tests/ops/utils.py Updates create_fused_moe() helper to pass new MoE ctor args.
tests/unit_tests/ops/test_hpu_fused_moe.py Updates MoE test to use ForwardContext + runner.forward().
tests/unit_tests/ops/test_hpu_compressed_tensors.py Updates MoE test similarly to use ForwardContext + runner.forward().
tests/unit_tests/kv_offload/utils.py Extends test helper output to include kv_connector_worker_meta.
tests/unit_tests/kv_offload/offloading_connector/utils.py Syncs offloading connector test harness with upstream async scheduling and keying changes.

@@ -111,7 +120,8 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig):

self.manager = MagicMock(spec=OffloadingManager)
self.manager.lookup.return_value = 0

Copilot AI Apr 29, 2026


MockOffloadingSpec sets manager.lookup.return_value twice (first to 0, then to False). This looks accidental and makes it unclear what the mocked contract is supposed to be. Please remove the redundant assignment and keep a single return type/value consistent with OffloadingManager.lookup's expected behavior.

Suggested change
self.manager.lookup.return_value = 0

Comment on lines +60 to +61
vllm.lora.utils._all_lora_classes = tuple(
    HPUVocabParallelEmbeddingWithLoRA if cls is VocabParallelEmbeddingWithLoRA
    else cls for cls in _all_lora_classes)

Copilot AI Apr 29, 2026


Switching vllm.lora.utils._all_lora_classes to a tuple here makes it immutable, but other code in this repo still treats it as a mutable set (e.g., vllm_gaudi/lora/layers/hpu_row_parallel_linear.py:64-69 calls discard/add). When VLLM_ROW_PARALLEL_CHUNKS > 1, register_hpu_lora_layers() will raise AttributeError. Consider updating all call sites to work with the tuple-based API (e.g., rebuild a new tuple with replacements/removals) or provide a small helper here that performs tuple-safe updates and is reused everywhere.

Suggested change
vllm.lora.utils._all_lora_classes = tuple(
    HPUVocabParallelEmbeddingWithLoRA if cls is VocabParallelEmbeddingWithLoRA
    else cls for cls in _all_lora_classes)

def _replace_lora_class(classes, old_cls, new_cls):
    updated_classes = [new_cls if cls is old_cls else cls for cls in classes]
    if isinstance(classes, tuple):
        return tuple(updated_classes)
    if isinstance(classes, set):
        return set(updated_classes)
    return type(classes)(updated_classes)

vllm.lora.utils._all_lora_classes = _replace_lora_class(
    _all_lora_classes,
    VocabParallelEmbeddingWithLoRA,
    HPUVocabParallelEmbeddingWithLoRA,
)

Comment on lines +493 to 496
assert scoring_func in ["sigmoid", "softmax", "sqrtsoftplus"]

if e_score_correction_bias is not None or hash_indices_table is not None:
return FusedTopKBiasRouter(

Copilot AI Apr 29, 2026


The docstring and priority-order comments are now out of sync with the implementation: scoring_func is asserted to allow "sqrtsoftplus", and FusedTopKBiasRouter is selected when hash_indices_table is provided (even if e_score_correction_bias is None). Please update the selection-order bullets and the scoring_func description in the docstring to match the new behavior.

Comment thread vllm_gaudi/ops/hpu_fused_moe.py
@pawel-olejniczak pawel-olejniczak force-pushed the fix/vllm-hourly-29-4 branch 2 times, most recently from 6c01231 to 342e65b Compare May 4, 2026 11:04
github-actions Bot commented May 5, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
5b39b268f506150dbab38f6f6c04b7c843e37c07

…A classes tuple (upstream #40560, #35077)

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
…k V4 router API (upstream #40860)

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
…ream per-job store completion (upstream #39186, #39403)

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
…_fused_moe_router (upstream #40860)

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
…h → forward (upstream #40560)

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
… per-job API (upstream #39186)

Upstream vLLM PR#39186 changed OffloadingConnectorMetadata from
reqs_to_store/reqs_to_load (dict[ReqId, TransferSpec]) to
store_jobs/load_jobs (dict[int, TransferJob]) with req_id inside
TransferJob. Update _get_prompts_and_decodes to extract req_ids
from the new structure.

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
…in HPU punica wrapper (upstream #35077)

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
…MoE runner refactor (upstream #40560)

- Handle is_internal_router being a read-only @Property on INC-wrapped
  FusedMoE modules (PatchedMixtralMoE) by checking for property
  descriptor before attempting attribute assignment.
- Accept new input_ids parameter in patched_fused_moe_forward to match
  upstream MoERunner.forward signature change.

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
…ejection_sample (upstream #40662)

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
…orward calls (upstream #40860)

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
github-actions Bot commented May 5, 2026

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

github-actions Bot commented May 5, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
5b39b268f506150dbab38f6f6c04b7c843e37c07

@iboiko-habana iboiko-habana merged commit ef4400b into vllm-project:main May 6, 2026
2 checks passed
