
[FIX_FOR_VLLM_CUSTOM=3975eb6de6ea914b9d7b27fd517e0c971ddeb6fc] Fix upstream breakages: NIXL connector, TpKVTopology rename, MoE refactor, transformers v5 #1377

Merged

iboiko-habana merged 11 commits into vllm-project:main from pawel-olejniczak:fix/vllm-hourly-tpkvtopo-rename on Apr 29, 2026

Conversation

@pawel-olejniczak (Contributor) commented Apr 21, 2026

Summary

Compatibility fixes for the vLLM bump to 3975eb6de6. Addresses breakages from multiple upstream PRs affecting the NIXL connectors, the MoE runner refactor, offloading tests, Qwen3 MoE models, and the transformers v5 upgrade.

Root Cause

  1. NIXL import gate — Upstream PR "nixl refactor [2/N]: unify TpKVTopology + HeteroTPTransferConfig into TransferTopology" (vllm#39529, commit cc3993b05d) moved NIXL imports to vllm/distributed/nixl_utils.py and changed the platform gate from if not is_rocm() to if is_cuda(). HPU is neither CUDA nor ROCm, so it falls into the else branch, tries rixl._api (ROCm-only), fails, and leaves NixlWrapper = None, which later raises RuntimeError("NIXL is not available"). See the sketch after this list.

  2. TpKVTopology rename — Same upstream PR #39529 unified TpKVTopology + HeteroTPTransferConfig into TransferTopology, breaking vllm-gaudi NIXL connector imports.

  3. Offloading tests — Upstream PR [kv_offload+HMA][4/N]: Support sliding window lookup vllm#36645 changed OffloadingManager.lookup() API.

  4. MoE runner refactor — Upstream PR "[MoE Refactor] Move the shared/fused expert output sum into MoERunnerBase" (vllm#35949, commit 726efe177b) moved reduce logic into MoERunnerBase, removing reduce_results and renaming forward_dispatch → _forward_dispatch, forward_entry → _forward_entry, and _maybe_reduce_output → _maybe_reduce_final_output. A follow-up PR moved MoERunnerBase and get_layer_from_name to moe_runner_base.py.

  5. Qwen3 MoE — SharedFusedMoE returns a combined tensor (not a tuple), and the MoE runner now handles TP reduction internally, causing a double reduce in qwen3_moe.py / qwen3_next.py.

  6. Transformers v5 — granite tokenizer — Upstream PR "Update to transformers v5" (vllm#30566) raised the transformers pin to allow v5. GPT2Tokenizer in v5 now respects add_bos_token=True (silently ignored in v4), causing degenerate outputs and 0.0 GSM8K accuracy on granite models.

  7. Transformers v5.6.x — DeepSeek-V2-Lite tokenizer — In transformers v5.6.x, LlamaTokenizerFast was unified into LlamaTokenizer, which does not apply the ByteLevel BPE decoder declared in tokenizer.json. DeepSeek-V2-Lite-Chat's tokenizer decoding strips all spaces (Ġ chars not converted back), producing garbled output and 0.0 accuracy on GSM8K. Fixed natively in transformers v5.7.0.
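
For root cause 1, a minimal sketch of the upstream platform gate as described above (the nixl._api import path matches upstream; the surrounding structure and the rixl symbol name are illustrative assumptions, not verbatim vLLM code):

```python
# Sketch of the gate in vllm/distributed/nixl_utils.py after commit cc3993b05d;
# structure paraphrased from the root-cause description, not copied from upstream.
from vllm.platforms import current_platform

try:
    if current_platform.is_cuda():
        from nixl._api import nixl_agent as NixlWrapper
    else:
        # ROCm-only package: HPU (neither CUDA nor ROCm) lands here and fails
        from rixl._api import nixl_agent as NixlWrapper  # hypothetical symbol name
except ImportError:
    NixlWrapper = None  # callers later raise RuntimeError("NIXL is not available")
```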

Fix

  1. NIXL import patch: Add patch_nixl_utils_for_hpu() in register_utils() to monkey-patch vllm.distributed.nixl_utils so that HPU imports from nixl._api instead of rixl._api (sketched after this list). Update hetero_hpu_nixl_connector.py to import from vllm.distributed.nixl_utils instead of hardcoding nixl._api.
  2. TpKVTopology → TransferTopology: Rename in NIXL connector imports and monkey-patches.
  3. Offloading tests: Replace runner.manager.lookup.return_value with connector_scheduler._maximal_prefix_lookup.
  4. MoE refactor: Update imports (MoERunnerBase from moe_runner_base), method names (_forward_dispatch, _forward_entry, _maybe_reduce_final_output), remove dead reduce_results / reduce_output().
  5. Qwen3 MoE: Remove incorrect shared_expert tuple indexing and double TP reduction.
  6. Transformers v5 — granite: Remove hardcoded add_bos_token=True from lm-eval model_args to fix GSM8K accuracy regression.
  7. Transformers v5.6.x — DeepSeek-V2-Lite: Exclude transformers 5.6.* in requirements.txt to prevent installation of versions with broken ByteLevel BPE tokenizer decoding. Verified on Gaudi2: gsm8k accuracy 0.65 (expected 0.66, within tolerance) with transformers 5.7.0.
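
A minimal sketch of the monkey-patch from fix 1, assuming the gate shape sketched above; the real patch_nixl_utils_for_hpu() in vllm-gaudi may differ in detail:

```python
# Hedged sketch: re-point the upstream module's NixlWrapper at nixl._api on HPU.
def patch_nixl_utils_for_hpu():
    import vllm.distributed.nixl_utils as nixl_utils
    try:
        from nixl._api import nixl_agent  # HPU uses the same package as CUDA
    except ImportError:
        return  # NIXL not installed: keep upstream's None/RuntimeError behavior
    nixl_utils.NixlWrapper = nixl_agent
```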


Copilot AI left a comment


Pull request overview

Updates vllm-gaudi’s KV transfer/offload integration to match upstream vLLM API changes (TpKVTopology → TransferTopology and OffloadingManager.lookup signature change), unblocking import-time failures and fixing broken offloading unit tests.

Changes:

  • Update NIXL connector imports/monkey-patches and topology construction to use TransferTopology.
  • Adjust heterogeneous HPU NIXL connector logic to use new TransferTopology fields/methods.
  • Update offloading scheduler tests to accommodate upstream lookup behavior changes.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py | Renames topology import and updates the HPU-specific __post_init__ monkey-patch to target TransferTopology. |
| vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hetero_hpu_nixl_connector.py | Replaces TpKVTopology with TransferTopology and updates downstream usages (layout flags, block size ratio). |
| tests/unit_tests/kv_offload/offloading_connector/utils.py | Exposes connector_scheduler on the test runner for updated scheduler test control. |
| tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py | Updates assertions and stubbing strategy to align with upstream offloading lookup changes. |

Comment on lines 84 to 88

```diff
 if self.is_mla and self._cross_layers_blocks:
-    logger.warning("[HPU] TpKVTopology: overriding false-positive _cross_layers_blocks=True "
+    logger.warning("[HPU] TransferTopology: overriding false-positive _cross_layers_blocks=True "
                    "for MLA model. HPU get_kv_cache_shape() returns 3-D tensors, causing "
                    "the dim-count heuristic to misfire. Forcing _cross_layers_blocks=False.")
     self._cross_layers_blocks = False
```

Copilot AI Apr 21, 2026


This monkey-patch relies on the private attribute _cross_layers_blocks of TransferTopology. Since TransferTopology is coming from upstream vLLM and was recently renamed/refactored, this is brittle and can fail at import time if the internal attribute changes. Prefer using a public API (e.g., cross_layers_blocks if it’s settable) or guard the override with hasattr/getattr so the patch degrades safely when upstream internals change.
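
A guarded version of the override along the lines suggested (attribute and flag names taken from the snippet above; this is a sketch, not the PR's code):

```python
import logging

logger = logging.getLogger(__name__)

def apply_mla_override(topology) -> None:
    # Only touch the upstream-private attribute if it still exists, so the
    # patch degrades to a no-op when TransferTopology internals change.
    if getattr(topology, "is_mla", False) and getattr(topology, "_cross_layers_blocks", False):
        logger.warning("[HPU] TransferTopology: forcing _cross_layers_blocks=False for MLA.")
        topology._cross_layers_blocks = False
```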

Comment on lines 97 to 101

```diff
 runner.scheduler.reset_prefix_cache()
 runner.new_request(token_ids=[0] * offloaded_block_size)
 runner.manager.prepare_store.side_effect = (lambda block_hashes, req_context: generate_store_output([]))
-runner.manager.lookup.return_value = 1
+runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
 runner.run(decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2))
```

Copilot AI Apr 21, 2026


These tests now stub the scheduler by overwriting _maximal_prefix_lookup directly. Because this is a private implementation detail, it makes the tests brittle to upstream refactors and bypasses coverage of the new OffloadingManager.lookup(key: OffloadKey) -> bool behavior. Consider mocking manager.lookup with a side_effect that returns True/False based on the provided OffloadKey (or providing a small helper/fake manager) so the test exercises the real prefix-lookup logic.
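
A sketch of that suggestion, assuming the upstream lookup(key: OffloadKey) -> bool signature quoted above (the fake store below is illustrative):

```python
# Stub manager.lookup by key instead of bypassing the scheduler's private
# _maximal_prefix_lookup, so the real prefix-lookup logic is exercised.
stored_keys = set()  # populate with the OffloadKeys the fake store "contains"

def fake_lookup(key, *args, **kwargs):
    return key in stored_keys

runner.manager.lookup.side_effect = fake_lookup
```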

```diff
 runner.run(decoded_tokens=[EOS_TOKEN_ID])
-runner.manager.lookup.assert_called()
-assert len(list(runner.manager.lookup.call_args.args[0])) == 1
+runner.manager.lookup.assert_called_once()
```

Copilot AI Apr 21, 2026


The assertion was weakened from validating the lookup input (previously checking the iterable length) to only checking lookup() was called once. To keep coverage of the signature change, consider asserting that lookup was called with an OffloadKey (and, if applicable, that req_context is passed via args/kwargs) so the test will fail if the old iterable-based call is accidentally reintroduced.

Suggested change

```diff
 runner.manager.lookup.assert_called_once()
+lookup_args, lookup_kwargs = runner.manager.lookup.call_args
+assert lookup_args
+assert isinstance(lookup_args[0], OffloadKey)
+if "req_context" in lookup_kwargs:
+    assert lookup_kwargs["req_context"] is not None
+else:
+    assert len(lookup_args) >= 2
+    assert lookup_args[1] is not None
```

Comment on lines +596 to 600

```python
self.transfer_topo = TransferTopology(
    tp_rank=self.tp_rank,
    tp_size=self.world_size,
    block_size=self.block_size,
    engine_id=self.engine_id,
```

Copilot AI Apr 21, 2026


TransferTopology is initialized with block_size=self.block_size here, but self.block_size can be mutated later in register_kv_caches() when the physical kernel block size differs (and self._block_size[self.engine_id] is updated). If TransferTopology.block_size_ratio() uses its stored local block_size, it can become stale and produce incorrect ratios/mappings during transfers. Consider constructing TransferTopology after the final effective block size is known, or updating/recreating self.transfer_topo whenever self.block_size is adjusted.
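
A sketch of the second option (rebuild on change); the field names mirror the construction snippet above, and the setter is a hypothetical helper, not the connector's actual API:

```python
def set_effective_block_size(self, block_size: int) -> None:
    # Recreate the topology whenever register_kv_caches() settles on a
    # different physical block size, so block_size_ratio() never goes stale.
    if block_size != self.block_size:
        self.block_size = block_size
        self.transfer_topo = TransferTopology(
            tp_rank=self.tp_rank,
            tp_size=self.world_size,
            block_size=self.block_size,
            engine_id=self.engine_id,
        )
```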

@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

@pawel-olejniczak force-pushed the fix/vllm-hourly-tpkvtopo-rename branch 3 times, most recently from 8f94303 to f4fdf89, on April 22, 2026 21:27
@pawel-olejniczak force-pushed the fix/vllm-hourly-tpkvtopo-rename branch 2 times, most recently from a1a7df4 to 35e2767, on April 23, 2026 09:33
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

@pawel-olejniczak force-pushed the fix/vllm-hourly-tpkvtopo-rename branch from 35e2767 to 216833d on April 23, 2026 09:43
@pawel-olejniczak changed the title from "[FIX_FOR_VLLM_CUSTOM=b47840019e61a3983c8144066a99c843d177947d] Fix TpKVTopology rename to TransferTopology and update offloading tests" to "[FIX_FOR_VLLM_CUSTOM=b47840019e61a3983c8144066a99c843d177947d] Fix upstream breakages: TpKVTopology rename, MoE runner refactor, Qwen3 MoE fixes" on Apr 23, 2026
@pawel-olejniczak force-pushed the fix/vllm-hourly-tpkvtopo-rename branch 5 times, most recently from ee8a974 to 6a3bedd, on April 25, 2026 11:31
@pawel-olejniczak changed the title from "[FIX_FOR_VLLM_CUSTOM=b47840019e61a3983c8144066a99c843d177947d] Fix upstream breakages: TpKVTopology rename, MoE runner refactor, Qwen3 MoE fixes" to "[FIX_FOR_VLLM_CUSTOM=3975eb6de6ea914b9d7b27fd517e0c971ddeb6fc] Fix upstream breakages: TpKVTopology rename, MoE runner refactor, Qwen3 MoE fixes" on Apr 26, 2026
@pawel-olejniczak force-pushed the fix/vllm-hourly-tpkvtopo-rename branch from 9bd6588 to 867e244 on April 26, 2026 19:26
@pawel-olejniczak force-pushed the fix/vllm-hourly-tpkvtopo-rename branch 2 times, most recently from 9ef636f to 4604d76, on April 27, 2026 13:57
@pawel-olejniczak changed the title from "[FIX_FOR_VLLM_CUSTOM=3975eb6de6ea914b9d7b27fd517e0c971ddeb6fc] Fix upstream breakages: TpKVTopology rename, MoE runner refactor, Qwen3 MoE fixes" to "[FIX_FOR_VLLM_CUSTOM=3975eb6de6ea914b9d7b27fd517e0c971ddeb6fc] Fix upstream breakages: NIXL connector, TpKVTopology rename, MoE refactor, transformers v5" on Apr 27, 2026
@pawel-olejniczak force-pushed the fix/vllm-hourly-tpkvtopo-rename branch from 8f775ef to 5b2f6c1 on April 27, 2026 20:12
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

…KVTopology rename to TransferTopology and update offloading tests

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
…and forward_dispatch rename (upstream PR #35949)

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
…(forward_entry, maybe_reduce_output)

Upstream PR #35949 (commit 726efe177b) renamed:
- forward_entry -> _forward_entry (now stored as private attribute)
- _maybe_reduce_output -> _maybe_reduce_final_output

Update patched_fused_moe_forward in vllm_gaudi/ops/hpu_fused_moe.py to use
the new method names. Without this, DP MoE paths and several e2e tests fail
with AttributeError on DefaultMoERunner.

Follow-up to the previous fix in this branch which addressed reduce_results
removal and forward_dispatch rename from the same upstream PR.

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
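
A hedged sketch of tolerating that rename in a monkey-patch (the shipped fix simply switches to the new private names; this fallback is an illustration, not the PR's code):

```python
def resolve_forward_entry(runner):
    """Return the MoE runner's dispatch entry point under the new or old name."""
    fn = getattr(runner, "_forward_entry", None) or getattr(runner, "forward_entry", None)
    if fn is None:
        raise AttributeError("MoE runner exposes neither _forward_entry nor forward_entry")
    return fn
```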
Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
…ve double TP reduction in MoE models

- Update hpu_fused_moe.py imports: DefaultMoERunner/moe_runner_base removed
  in upstream PR #40560 (combined into MoERunner in moe_runner.py)
- Remove incorrect shared_expert tuple indexing in qwen3_next.py
  (SharedFusedMoE returns combined tensor, not tuple)
- Remove double tensor_model_parallel_all_reduce in qwen3_moe.py and
  qwen3_next.py (MoE runner handles TP reduction internally after PR #35949)

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
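
An illustrative sketch of the resulting model-side call pattern (attribute names are generic placeholders; the real change is in upstream vLLM's qwen3_moe.py / qwen3_next.py):

```python
def moe_block_forward(self, hidden_states, router_logits):
    # Before: the output was unpacked as a (shared, routed) tuple and reduced
    # again with tensor_model_parallel_all_reduce, double-counting the TP
    # reduction. After PR #35949, SharedFusedMoE returns one combined,
    # already-reduced tensor, so the model just passes it through.
    return self.experts(hidden_states, router_logits)
```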
…Base from moe_runner_base

At vLLM b47840019e, MoERunnerBase and get_layer_from_name live in
moe_runner_base.py (not moe_runner.py). MoERunner in moe_runner.py is
just an ABC without __init__ or forward. Fix imports and monkey-patch
target to use MoERunnerBase directly.

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
Align with the vLLM change (PR #30566)

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
This reverts commit 9148d75.
Since vLLM 3975eb6de6 is reverted, this is no longer needed.

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
Remove hardcoded add_bos_token=True from lm-eval model_args.

In transformers v4, GPT2TokenizerFast silently ignored add_bos_token=True.
In transformers v5, GPT2Tokenizer respects it and prepends the BOS token
(id=0, <|end_of_text|> for granite) to every prompt. This causes the model
to see an end-of-text signal at position 0, producing degenerate outputs
(repetition loops) and 0.0 accuracy on GSM8K CoT benchmarks.

The parameter was effectively a no-op with v4 and is not needed — granite
tokenizer has add_bos_token=False by default.

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
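
A hedged sketch of the lm-eval setup in question; the model path and task name are illustrative, not the exact CI configuration:

```python
from lm_eval import simple_evaluate

results = simple_evaluate(
    model="vllm",
    # The removed ",add_bos_token=True" suffix was a no-op under transformers
    # v4 but prepends BOS (<|end_of_text|>, id=0) under v5, breaking GSM8K.
    model_args="pretrained=ibm-granite/granite-3.3-8b-instruct",
    tasks=["gsm8k_cot"],
)
```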
…l._api

Upstream vLLM commit 3975eb6de6 gates NIXL imports on is_cuda(), falling back to rixl._api for non-CUDA platforms. HPU needs nixl._api (same as CUDA). This adds a monkey-patch in register_utils() and fixes the hetero_hpu_nixl_connector to import from vllm.distributed.nixl_utils instead of hardcoding nixl._api.

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
In transformers v5.6.x, LlamaTokenizerFast was unified into LlamaTokenizer
which does not apply the ByteLevel BPE decoder declared in tokenizer.json.
This causes DeepSeek-V2-Lite-Chat decoding to strip all spaces (Ġ chars not
converted back), producing garbled output and 0.0 accuracy on GSM8K.

transformers v5.7.0 resolves this by returning TokenizersBackend from
AutoTokenizer.from_pretrained(), which correctly applies the ByteLevel
decoder. Exclude 5.6.* in requirements.txt to prevent installation of
the broken versions.

Verified on Gaudi2 pod: gsm8k accuracy 0.65 (expected 0.66, within tolerance)
with transformers 5.7.0 and no code changes.

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
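
For reference, a sketch of the requirements.txt constraint (the exact specifier and any surrounding version bounds in the repo may differ):

```
# requirements.txt (illustrative): skip the 5.6 series with broken ByteLevel
# BPE decoding; 5.7.0+ returns TokenizersBackend and decodes correctly.
transformers!=5.6.*
```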
@pawel-olejniczak force-pushed the fix/vllm-hourly-tpkvtopo-rename branch from 84a0171 to 7e937d0 on April 28, 2026 19:32
@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
3975eb6de6ea914b9d7b27fd517e0c971ddeb6fc

@iboiko-habana merged commit c2c370f into vllm-project:main on Apr 29, 2026
71 checks passed