[FIX_FOR_VLLM_CUSTOM=fc701c80588c215f84af0b745edcf4d127e276bc] Fix upstream regressions in HPU worker, MoE router, and offloading tests#1354

Merged
adobrzyn merged 3 commits into vllm-project:main from pawel-olejniczak:fix/vllm-hourly-15-4 on Apr 16, 2026
Conversation

@pawel-olejniczak (Collaborator) commented Apr 15, 2026

Fix three upstream regressions that break HPU unit tests.

Changes

  1. vllm_gaudi/v1/worker/hpu_worker.py — compile_or_warm_up_model() now returns a
    CompilationTimes NamedTuple instead of a plain float, matching the new upstream
    contract introduced in "Measure encoder compile time separate from llm backbone" vllm#39240.

  2. vllm_gaudi/ops/hpu_fused_moe.py — Add zero_expert_type and num_logical_experts
    parameters to the HPU override of create_fused_moe_router(), plus ZeroExpertRouter
    dispatch, matching the refactor in [MoE Refactor] Refactor ZeroExpertFusedMoE into new framework vllm#35549.

  3. tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py — Remove
    block_size from OffloadingEvent constructor calls and update assertion, matching
    the field removal in [kv_offload+HMA][3/N]: Remove block_size from KVEvents vllm#36644.

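The return-type change in item 1 can be sketched as follows. This is a minimal illustration, not the actual worker code: the CompilationTimes field names (language_model, encoder) are taken from the review summary below, and the timing values here are placeholders.

```python
from typing import NamedTuple


# Hypothetical mirror of upstream vLLM's CompilationTimes contract;
# field names inferred from the review summary (language_model + encoder).
class CompilationTimes(NamedTuple):
    language_model: float
    encoder: float = 0.0


def compile_or_warm_up_model() -> CompilationTimes:
    # A real HPU worker would measure warmup/compile durations here;
    # these constants stand in for the measured times.
    lm_seconds = 1.5
    encoder_seconds = 0.0
    # Previously this returned a bare float; callers now unpack a NamedTuple.
    return CompilationTimes(language_model=lm_seconds, encoder=encoder_seconds)
```

Because NamedTuple instances are still tuples, any caller that merely summed or logged the old float must be updated, which is what forced the plugin-side change.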
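The router-factory change in item 2 amounts to a new dispatch branch. The sketch below is illustrative only: the names ZeroExpertRouter, zero_expert_type, and num_logical_experts come from the PR description, but the class bodies and the default-router stand-in are invented for this example.

```python
from typing import Optional


class DefaultRouter:
    """Stand-in for the pre-existing fused-MoE router."""


class ZeroExpertRouter:
    """Stand-in for upstream's zero-expert router (vllm#35549)."""

    def __init__(self, zero_expert_type: str, num_logical_experts: int):
        self.zero_expert_type = zero_expert_type
        self.num_logical_experts = num_logical_experts


def create_fused_moe_router(
    zero_expert_type: Optional[str] = None,
    num_logical_experts: Optional[int] = None,
):
    # Dispatch to the zero-expert router only when the model configures it;
    # otherwise fall back to the ordinary router, preserving old behavior.
    if zero_expert_type is not None:
        return ZeroExpertRouter(zero_expert_type, num_logical_experts or 0)
    return DefaultRouter()
```

Keeping both parameters optional means existing call sites that never pass zero-expert arguments continue to get the default router unchanged.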
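Item 3 is a pure test-side update. The sketch below uses a stand-in dataclass (its other field names are assumptions, not the real upstream definition) to show the shape of the fix: the block_size field is gone, so constructor calls and assertions that referenced it had to be dropped.

```python
from dataclasses import dataclass


@dataclass
class OffloadingEvent:
    # Stand-in for the upstream KV-offload event; block_size was removed
    # in vllm-project/vllm#36644, so the constructor no longer accepts it.
    block_hashes: list
    medium: str


# After the fix: construct the event without block_size.
event = OffloadingEvent(block_hashes=[1, 2], medium="cpu")
assert not hasattr(event, "block_size")
```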
Fixed tests

  • tests/unit_tests/lora/test_llama_tp.py::test_llama_lora
  • tests/unit_tests/lora/test_llm_with_multi_loras.py::test_multiple_lora_requests
  • tests/unit_tests/test_embedding.py::test_embeddings[intfloat/e5-mistral-7b-instruct]
  • tests/unit_tests/ops/test_hpu_fused_moe.py::test_unquantized_fused_moe_method
  • tests/unit_tests/ops/test_hpu_compressed_tensors.py::test_compressed_tensors_wna16_moe_method
  • tests/unit_tests/ops/test_hpu_compressed_tensors.py::test_compressed_tensors_w8a8fp8_block_moe_method
  • tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py::test_offloading_connector[True]
  • tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py::test_offloading_connector[False]

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
Copilot AI (Contributor) left a comment


Pull request overview

Fixes upstream-compat regressions in the Gaudi vLLM plugin by aligning (1) the HPU worker warmup/compile return type with the new vLLM CompilationTimes contract and (2) the HPU fused-MoE router factory signature/dispatch with upstream’s zero-expert routing refactor.

Changes:

  • Update HPUWorker.compile_or_warm_up_model() to return a CompilationTimes NamedTuple (language_model + encoder) instead of a float.
  • Extend create_fused_moe_router() override to accept zero_expert_type / num_logical_experts and dispatch to ZeroExpertRouter when configured.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

  • vllm_gaudi/v1/worker/hpu_worker.py — Align worker warmup/compile return value with upstream CompilationTimes interface.
  • vllm_gaudi/ops/hpu_fused_moe.py — Add zero-expert router parameters and routing selection to match upstream MoE router factory changes.

@pawel-olejniczak changed the title from "[FIX_FOR_VLLM_CUSTOM=fc701c80588c215f84af0b745edcf4d127e276bc] Fix CompilationTimes return type and add zero_expert params to HPU MoE router" to "[FIX_FOR_VLLM_CUSTOM=fc701c80588c215f84af0b745edcf4d127e276bc] Fix upstream regressions in HPU worker, MoE router, and offloading tests" on Apr 15, 2026
@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
fc701c80588c215f84af0b745edcf4d127e276bc

@adobrzyn adobrzyn merged commit c8958e6 into vllm-project:main Apr 16, 2026
71 checks passed
yeonsily pushed a commit to yeonsily/vllm-gaudi that referenced this pull request Apr 21, 2026
…stream regressions in HPU worker, MoE router, and offloading tests (vllm-project#1354)
Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
Signed-off-by: Yeonsil Yoon <yeon.sil.yoon@intel.com>
bmyrcha pushed a commit to bmyrcha/vllm-gaudi that referenced this pull request Apr 22, 2026
…stream regressions in HPU worker, MoE router, and offloading tests (vllm-project#1354)
Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
Signed-off-by: bmyrcha <bartosz.myrcha@intel.com>


3 participants