[Spec Decode] Support hybrid attention models in extract_hidden_states#39949

Merged
vllm-bot merged 11 commits into
vllm-project:mainfrom
neuralmagic:extract-hidden-states-hybrid
May 13, 2026
Conversation

@mgoin
Member

@mgoin mgoin commented Apr 15, 2026

Summary

Hidden-state extraction now works on hybrid-attention models (e.g. Qwen3.5). The kv-transfer config no longer force-disables HMA; it stays on for connectors that declare SupportsHMA.

The KV-cache grouping handles the cache-only hidden-state layer alongside a hybrid attention/Mamba layout by filtering it out before unification and adding it back as its own page-aligned 1-layer group, with a strided reshape in the model runner to span the padded page.

ExampleHiddenStatesConnector is updated accordingly, and a new CI job runs the integration tests.
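
The gating described in the summary can be sketched in a few lines. This is a minimal illustration, not vLLM's code: `SupportsHMA` and `supports_hma` echo names used in this PR, while `LegacyConnector`, `HiddenStatesConnector`, and `should_disable_hma` are hypothetical, and treating "no connector" as "leave HMA on" is an assumption.

```python
from typing import Optional


class SupportsHMA:
    """Marker mixin: a connector that can run with the hybrid KV cache manager."""


class LegacyConnector:  # hypothetical connector without HMA support
    pass


class HiddenStatesConnector(SupportsHMA):  # hypothetical HMA-aware connector
    pass


def supports_hma(connector_cls: type) -> bool:
    # Class-level check, so it can run before any connector is instantiated.
    return issubclass(connector_cls, SupportsHMA)


def should_disable_hma(connector_cls: Optional[type]) -> bool:
    # Assumption: with no connector configured, nothing forces HMA off.
    if connector_cls is None:
        return False
    # HMA is disabled only for connectors that do not declare support.
    return not supports_hma(connector_cls)
```

Because the check is on the class rather than an instance, it can be evaluated while the config is still being validated.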

Test plan

  • tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py — Llama end-to-end (GPU)
  • Qwen3.5-4B + extract_hidden_states — hybrid model end-to-end (GPU), hidden states shape [N, 3, 2560] with non-zero values
  • pre-commit run ruff-check / ruff-format / mypy-3.10 — all passing
  • CI

🤖 AI-assisted (Claude)

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces the HiddenStateCacheSpec to support hidden-state extraction within the vLLM V1 engine. Key changes include updating the KV cache grouping heuristics to prevent singleton cache-only layers from collapsing group sizes, refactoring the ExampleHiddenStatesConnector to utilize attn_metadata.slot_mapping directly, and implementing dynamic HMA (hybrid memory allocator) support checks for connectors. Feedback is provided regarding the max_memory_usage_bytes implementation in HiddenStateCacheSpec, which currently fails to account for context parallelism, potentially leading to memory over-estimation during initialization.

Comment thread vllm/v1/kv_cache_interface.py Outdated
Hidden-state extraction breaks on hybrid-attention models (e.g.
Qwen3.5) because kv_transfer_config force-disables HMA and
unify_hybrid_kv_cache_specs cannot fold MambaSpec into a uniform type.

Fix by gating HMA-disable on supports_hma(connector_cls), making
ExampleHiddenStatesConnector a SupportsHMA subclass, and handling the
cache-only layer's page alignment for hybrid models. Key changes:

- HiddenStateCacheSpec: thin marker subclass of MLAAttentionSpec
  (inherits all dispatch behavior, no overrides). Defined in
  kv_cache_interface.py, registered in spec_manager_map.
- get_kv_cache_groups: filter HiddenStateCacheSpec out before
  unify/grouping, add back as 1-layer group with page_size_padded
  aligned to the common page. General sub-functions untouched.
- gpu_model_runner: as_strided reshape branch for padded specs
  (page_size_padded > real_page), proposer isinstance for kv_cache_gid.
- Connector: read slot_mapping from attn_metadata (not scheduler
  block_ids), remove dead ReqMeta.slot_mapping field.
- Proposer: kv_cache_gid for correct common_attn_metadata selection.
- basic_cache/extract_from_kv_cache: block/offset indexing instead of
  flatten (works on non-contiguous strided tensors).
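
To make the page padding and the block/offset indexing concrete, here is a small pure-Python sketch. It assumes "aligned to the common page" means rounding the real page size up to a multiple of the common page, and it uses a `bytearray` in place of a strided tensor; all names and numbers are illustrative, not taken from vLLM.

```python
def round_up(value: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= value."""
    return -(-value // multiple) * multiple


def slot_to_byte_range(slot, block_size, token_bytes, page_size_padded):
    """Map a flat slot index to its byte range in a page-padded buffer."""
    block_idx, offset = divmod(slot, block_size)
    start = block_idx * page_size_padded + offset * token_bytes
    return start, start + token_bytes


# Illustrative numbers only (not taken from any real model).
BLOCK_SIZE = 4                                   # tokens per block
TOKEN_BYTES = 6                                  # bytes of hidden state per token
REAL_PAGE = BLOCK_SIZE * TOKEN_BYTES             # 24 bytes of real data per block
COMMON_PAGE = 32                                 # page size of the other groups
PAGE_PADDED = round_up(REAL_PAGE, COMMON_PAGE)   # 32: byte stride between blocks

buffer = bytearray(3 * PAGE_PADDED)  # three blocks, each padded to one page


def write_token(slot: int, payload: bytes) -> None:
    start, end = slot_to_byte_range(slot, BLOCK_SIZE, TOKEN_BYTES, PAGE_PADDED)
    buffer[start:end] = payload


def read_token(slot: int) -> bytes:
    start, end = slot_to_byte_range(slot, BLOCK_SIZE, TOKEN_BYTES, PAGE_PADDED)
    return bytes(buffer[start:end])


write_token(3, b"AAAAAA")  # last token of block 0
write_token(4, b"BBBBBB")  # first token of block 1
```

A flattened view would place slot 4 at byte `4 * TOKEN_BYTES = 24`, which is padding in this layout; indexing by block and offset lands on byte 32 instead, which is why the flatten-based path stops working once pages are padded.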

Verified: Llama integration test + Qwen3.5-4B end-to-end on GPU.

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@mgoin mgoin force-pushed the extract-hidden-states-hybrid branch from 12019e0 to 530539a Compare April 16, 2026 16:43
@mgoin mgoin requested a review from xuechendi as a code owner April 16, 2026 16:43
Member

@benchislett benchislett left a comment

Besides the other included comments, my main concern here is that we're adding a very manual specialization for this extractor. This creates deviation from existing standard practices (native padding e.g. in mamba) and requires us to maintain the codepath as the HMA changes. This seems quite fragile, especially since bugs in the hidden extractor would not manifest during routine debugging / correctness testing.

I'm not convinced that it isn't feasible to simply apply padding to the HiddenStateCacheSpec (or to the other blocks, whichever is smaller; either case should be fine) and handle that natively, either in the kv connector itself or in the client code.

Specifically, I have problems with the code in gpu_model_runner.py which manually strides the view to avoid the padding. This seems to break some conventions about how we handle the KV connectors (which, AFAICT, seem to keep the padding?). Also, the special handling of the hidden state group in vllm/v1/core/kv_cache_utils.py seems fragile and prone to issues.

Comment thread vllm/v1/spec_decode/extract_hidden_states.py Outdated
Comment thread vllm/v1/core/kv_cache_utils.py Outdated
@johnnynunez
Contributor

this is good @mgoin... seems to enable nemotron too?

johnnynunez added a commit to johnnynunez/vllm that referenced this pull request Apr 28, 2026
… into nemotron-eagle3-support

Brings in @mgoin's PR vllm-project#39949 (vllm-project/vllm) which lets the
extract_hidden_states speculative method work on hybrid attention
backbones (Mamba-2 / GatedDeltaNet + attention).

Why pull this in early:

* Validated end-to-end against nvidia/NVIDIA-Nemotron-3-Nano-4B-BF16
  on a DGX Spark (GB10): vLLM serve OK, /health 200, speculators
  online DFlash training runs ~150 real loss steps before tuning-related
  NaN. Without vllm-project#39949, vLLM cannot even start the verifier
  (NotImplementedError on page-size unification).
* Composes cleanly with this branch's NemotronH SupportsEagle3 hooks
  (5925cca..f92ca38) -- no overlap in files, single conflict was
  in gpu_model_runner.py only.

Conflict resolution:

* vllm/v1/worker/gpu_model_runner.py -- kept the upstream/main version
  of the as_strided branch in _reshape_kv_cache_tensors. Upstream
  refactored this path to gate on `kv_cache_spec.page_size_padded is
  not None` and to allocate strides via `torch.empty(kv_cache_shape).
  stride()`, which is functionally equivalent to the PR's
  `> real_page_size_bytes` check + manual `inner.stride()` on this
  spec but more general (already in upstream main, post-dates the PR).
  All other bits of the PR (HiddenStateCacheSpec marker class,
  SupportsHMA on ExampleHiddenStatesConnector, kv_cache_utils group
  filter+add-back, attn_metadata.slot_mapping, proposer kv_cache_gid)
  apply unchanged.

Verified post-merge:

  python ast.parse on all 8 changed files -> OK
  HiddenStateCacheSpec class present at vllm/v1/kv_cache_interface.py:376
  ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA) confirmed

If/when vllm-project#39949 lands upstream, the merge will resolve to a no-op for
this branch; nothing here forks the PR.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Johnny Nunez <johnnynuca14@gmail.com>
Made-with: Cursor

Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
@benchislett
Member

I'm pretty sure the manual striding was merged into gpu model runner for DSV4. Assuming we can piggy-back on that, I think we can get this in without much change. Can you merge main?

This PR's vllm/config/vllm.py change replaced the unconditional HMA
force-disable for kv_transfer_config with a per-connector supports_hma
check. Two test patterns started hitting the factory's "HMA enabled
but connector doesn't support it" raise:

1. Tests that built VllmConfig with default Nixl (supports_hma) and
   then mutated kv_transfer_config to a non-HMA connector. By that
   point __post_init__ had already set disable_hybrid_kv_cache_manager
   based on Nixl, so the factory raised when the actual connector was
   built.

2. MultiConnector subclasses SupportsHMA at the class level, so
   __post_init__ left HMA enabled, but MultiConnector.__init__ asserts
   every sub-connector supports HMA at runtime.

Fixes:

- vllm/config/vllm.py: when the connector is MultiConnector, recurse
  into kv_connector_extra_config["connectors"] and AND-fold their
  supports_hma so the auto-disable matches the runtime assertion.
  Unblocks test_multi_connector_mixed_hma_disables_hybrid_kv_cache,
  which explicitly verifies this behavior.

- tests/v1/kv_connector/unit/utils.py: add kv_connector_module_path
  parameter to create_vllm_config so tests can build the config with
  the actual external connector class instead of mutating after
  construction.

- 4 unit tests: switch from post-construction kv_transfer_config
  mutation to construction-time create_vllm_config arguments.
  __post_init__ then resolves the real connector class and
  auto-disable kicks in correctly. No disable_hybrid_kv_cache_manager
  overrides in test code.

- .buildkite/test_areas/misc.yaml: add "V1 Extract Hidden States
  Integration" job covering tests/v1/kv_connector/extract_hidden_states_integration
  (predictable-Llama and Qwen3.5-0.8B hybrid smoke tests).
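
Both failure patterns and the fix can be reproduced with a toy config. This is a sketch under stated assumptions: `ToyConfig`, `HMA_SUPPORT`, and `effective_hma_support` are hypothetical stand-ins, and the connector names are placeholders rather than real registry entries.

```python
from dataclasses import dataclass, field
from typing import Any, Dict

# Hypothetical registry: connector name -> class-level HMA support.
HMA_SUPPORT = {"NixlLike": True, "PlainConnector": False, "Multi": True}


def effective_hma_support(name: str, extra: Dict[str, Any]) -> bool:
    supported = HMA_SUPPORT[name]
    if supported and name == "Multi":
        # AND-fold: a multi-connector supports HMA only if every child does.
        supported = all(
            effective_hma_support(
                sub["kv_connector"], sub.get("kv_connector_extra_config", {})
            )
            for sub in extra.get("connectors", [])
        )
    return supported


@dataclass
class ToyConfig:
    kv_connector: str = "NixlLike"
    extra: Dict[str, Any] = field(default_factory=dict)
    disable_hybrid_kv_cache_manager: bool = field(init=False)

    def __post_init__(self) -> None:
        # Derived once at construction; later mutation does not refresh it.
        self.disable_hybrid_kv_cache_manager = not effective_hma_support(
            self.kv_connector, self.extra
        )


# Pattern 1: mutate after construction, so the derived flag is stale.
stale = ToyConfig()
stale.kv_connector = "PlainConnector"

# Fix for pattern 1: pass the real connector at construction time.
fresh = ToyConfig(kv_connector="PlainConnector")

# Pattern 2: a multi-connector with one child that lacks HMA support.
mixed = ToyConfig(
    kv_connector="Multi",
    extra={"connectors": [{"kv_connector": "NixlLike"},
                          {"kv_connector": "PlainConnector"}]},
)
```

Here `stale` still reports HMA as enabled even though its connector no longer supports it, while `fresh` and `mixed` both end up with the hybrid KV cache manager disabled.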

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@mgoin mgoin force-pushed the extract-hidden-states-hybrid branch from 5e2b47d to 7a76dd7 Compare May 8, 2026 19:16
mgoin and others added 3 commits May 8, 2026 15:20
The autouse fixture registers PredictableLlamaForCausalLM in the parent
pytest process via ModelRegistry.register_model. CI sets
VLLM_WORKER_MULTIPROC_METHOD=spawn, which starts the engine worker as a
fresh Python process that doesn't inherit the registration, so the
worker fails with "Model architectures ['PredictableLlamaForCausalLM']
are not supported".

Override to fork inside the test so the worker inherits the parent's
ModelRegistry state. The Qwen3.5 hybrid smoke test in the same file
doesn't need this override (it uses a real registered architecture).

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The HMA auto-disable resolves the kv connector class via
KVConnectorFactory.get_connector_class(self.kv_transfer_config), which
raises if kv_connector is None. The kv_offloading_backend path leaves
kv_connector unset until _post_init_kv_transfer_config() populates it
(e.g. "OffloadingConnector", "LMCacheConnectorV1"), so callers using
--kv-offloading-backend without --kv-connector hit a ValidationError.

Move _post_init_kv_transfer_config() to run before the HMA block. This
also fixes a latent issue where the earlier cudagraph-mode connector
inspection would see a stale user-provided kv_connector before
kv_offloading overwrote it.

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@mgoin mgoin requested a review from Harry-Chen as a code owner May 9, 2026 17:12
@mergify
Contributor

mergify Bot commented May 9, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mgoin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 9, 2026
mgoin added 2 commits May 12, 2026 10:34
Resolved two conflicts:
- vllm/config/vllm.py: kept early _post_init_kv_transfer_config(),
  added _verify_kv_transfer_compat() from main
- test_backwards_compatibility.py: accepted main's deletion (compat removed)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: mgoin <mgoin64@gmail.com>
@mergify mergify Bot removed the needs-rebase label May 12, 2026
@vllm-bot vllm-bot merged commit 2f821fa into vllm-project:main May 13, 2026
91 of 94 checks passed
@DarkLight1337
Member

DarkLight1337 commented May 14, 2026

I'm trying to replicate the setup for training a model like https://huggingface.co/z-lab/Qwen3.5-27B-DFlash, and the memory usage is quite unreasonable on AMD: it tells me I need over 200 GB of VRAM for a 27B model...

Docker:

services:
  vllm:
    image: vllm/vllm-openai-rocm:nightly-bf0d2dc6d764f7ab1a69504f60a55883ec6d9b39
    container_name: vllm
    command: |
      --model Qwen/Qwen3.5-27B --enable-prefix-caching --limit-mm-per-prompt '{"image": 1, "video": 0}' --reasoning-parser qwen3 --default-chat-template-kwargs '{"enable_thinking": false}' --allowed-local-media-path / --log-error-stack --speculative_config '{"method": "extract_hidden_states", "num_speculative_tokens": 1, "draft_model_config": {"hf_config": {"eagle_aux_hidden_state_layer_ids": [1, 16, 31, 46, 61]}}}' --kv_transfer_config '{"kv_connector": "ExampleHiddenStatesConnector", "kv_role": "kv_producer", "kv_connector_extra_config": {"shared_storage_path": "/tmp/hidden_states"}}'
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - HF_TOKEN=${HF_TOKEN}
    network_mode: host
    group_add:
      - video
    ipc: host
    cap_add:
      - SYS_PTRACE
    security_opt:
      - seccomp=unconfined
    devices:
      - /dev/kfd
      - /dev/dri
    ports:
      - 8000:8000
    restart: unless-stopped
    volumes:
      - type: bind
        source: ~/.cache/huggingface/hub/
        target: /root/.cache/huggingface/hub/

Error message:

ValueError: To serve at least one request with the models's max seq len (262144), (204.06 GiB KV cache is needed, which is larger than the available KV cache memory (158.0 GiB). Based on the available memory, the estimated maximum model length is 202952. Try increasing `gpu_memory_utilization` or decreasing `max_model_len` when initializing the engine. See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ for more details.

Things I tried but didn't help:

  • Setting --max-num-seqs 1
  • Disabling prefix caching (this is required regardless of OOM because of block size assertion in HybridKVCacheCoordinator)
  • Eager mode

Member

@NickLucche NickLucche left a comment

Changes here went a bit too deep with the MultiConnector bit @mgoin :(

Comment thread vllm/config/vllm.py
Comment on lines +1371 to +1395
all_support_hma = supports_hma(connector_cls)
# MultiConnector subclasses SupportsHMA; only effectively
# supports HMA when every sub-connector does.
if all_support_hma and connector_cls.__name__ == "MultiConnector":
    sub_ktcs = self.kv_transfer_config.kv_connector_extra_config.get(
        "connectors", []
    )
    all_support_hma = all(
        supports_hma(
            KVConnectorFactory.get_connector_class(
                KVTransferConfig(**sub)
            )
        )
        for sub in sub_ktcs
    )
if not all_support_hma:
    need_disable_hybrid_kv_cache_manager = True
    logger.warning(
        "Turning off hybrid kv cache manager because "
        "connector %s does not subclass `SupportsHMA`. "
        "This will reduce performance on models with "
        "sliding window or Mamba attention. See "
        "kv_connector/v1/base.py for details.",
        connector_cls.__name__,
    )
Member

@NickLucche NickLucche May 14, 2026

we were tracking this change here #41847 and #42024 !

@DarkLight1337
Member

DarkLight1337 commented May 14, 2026

I worked around the OOM by setting a lower value of --max-model-len. But now I'm getting some hidden states mismatch issues when running offline data generation in vllm-project/speculators#495 (I set --validate-outputs so the input tokens should be correct), e.g.:

WARNING  Skipping sample 17 due to error: Sequence length of hidden states 390 in output/speculators/hidden_states/hs_17.safetensors doesn't match num tokens 496
WARNING  Skipping sample 458 due to error: Sequence length of hidden states 625 in output/speculators/hidden_states/hs_458.safetensors doesn't match num tokens 3888
WARNING  Skipping sample 3460 due to error: Sequence length of hidden states 133 in output/speculators/hidden_states/hs_3460.safetensors doesn't match num tokens 449
WARNING  Skipping sample 3555 due to error: Sequence length of hidden states 401 in output/speculators/hidden_states/hs_3555.safetensors doesn't match num tokens 433
WARNING  Skipping sample 3968 due to error: Sequence length of hidden states 35 in output/speculators/hidden_states/hs_3968.safetensors doesn't match num tokens 397
WARNING  Skipping sample 4252 due to error: Sequence length of hidden states 48 in output/speculators/hidden_states/hs_4252.safetensors doesn't match num tokens 271
WARNING  Skipping sample 4333 due to error: Sequence length of hidden states 184 in output/speculators/hidden_states/hs_4333.safetensors doesn't match num tokens 488
WARNING  Skipping sample 4399 due to error: Sequence length of hidden states 110 in output/speculators/hidden_states/hs_4399.safetensors doesn't match num tokens 798
WARNING  Skipping sample 5365 due to error: Sequence length of hidden states 127 in output/speculators/hidden_states/hs_5365.safetensors doesn't match num tokens 343

This only happens roughly once every 500 samples, and the sample numbers that are failing are not deterministic. This leads me to think that it's a race condition of some sort.

@DarkLight1337
Member

DarkLight1337 commented May 14, 2026

Ok this happens with Qwen/Qwen3-VL-30B-A3B-Instruct as well, so maybe it is a problem for multimodal models in general rather than hybrid models. Turning off CUDA graph and async scheduling didn't help either.

@DarkLight1337
Member

cc @fynnsu

@fynnsu
Contributor

fynnsu commented May 14, 2026

Hmm strange that the hidden state shape is less than the num tokens. I will try to repro this.

@DarkLight1337
Member

If I set --concurrency 1 in the offline generation script then the problem goes away.

@fynnsu
Contributor

fynnsu commented May 15, 2026

@DarkLight1337 Okay, I was able to reproduce the issue with Qwen/Qwen3-VL-30B-A3B-Instruct. I actually still see the problem when using --concurrency 1.

Claude seems to think it's an issue with chunked prefill, and I'm trying to develop a fix. Unfortunately this model doesn't seem to work if you just --no-enable-chunked-prefill. I'll keep working on this, and let you know when I have an update.

@fynnsu
Contributor

fynnsu commented May 15, 2026

This branch has a WIP fix: main...fynnsu:vllm:fix_chunked_prefill_hs_connector

That seems to solve the problem but it also makes the save happen after the request finishes so we need to add handling in speculators for this. (To test this I just temporarily hardcoded in a 2-second wait before reading the file in data_generation_offline.py).

#37374 also enforces a single save (by requiring no-chunked-prefill) and adds async lock file logic so that the downstream client knows when it's safe to read the file. I think we can build a proper fix off of that PR. This should also get around the no-chunked-prefill requirement because instead we just save all the blocks when the request finishes, which allows chunked prefill to run as normal.
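
The difference between saving per chunk and saving once at request finish can be illustrated with a toy model. The failure mode shown here (each chunk save overwriting the file with only that chunk's rows) is an assumption made for illustration, not a confirmed diagnosis, and `ToyConnector` with its method names is hypothetical.

```python
from typing import Dict, List


class ToyConnector:
    """Hypothetical connector; dicts stand in for the KV cache and for files."""

    def __init__(self, save_on_finish: bool) -> None:
        self.save_on_finish = save_on_finish
        self.cache: Dict[str, List[int]] = {}  # req_id -> rows written so far
        self.files: Dict[str, List[int]] = {}  # req_id -> rows in the saved file

    def on_prefill_chunk(self, req_id: str, chunk_rows: List[int]) -> None:
        self.cache.setdefault(req_id, []).extend(chunk_rows)
        if not self.save_on_finish:
            # Assumed bug: every chunk overwrites the file with its own rows.
            self.files[req_id] = list(chunk_rows)

    def on_request_finished(self, req_id: str) -> None:
        if self.save_on_finish:
            # Save everything accumulated in the cache, exactly once.
            self.files[req_id] = list(self.cache[req_id])


def run(save_on_finish: bool, num_tokens: int = 10, chunk: int = 4) -> int:
    conn = ToyConnector(save_on_finish)
    tokens = list(range(num_tokens))
    for start in range(0, num_tokens, chunk):
        conn.on_prefill_chunk("req-0", tokens[start:start + chunk])
    conn.on_request_finished("req-0")
    return len(conn.files["req-0"])
```

With 10 prompt tokens and a chunk size of 4, `run(False)` leaves a file holding only the last chunk, while `run(True)` saves all 10 rows. The cost is that the file only exists after the request finishes, so the reader needs a completion signal.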

@DarkLight1337
Member

DarkLight1337 commented May 18, 2026

> This branch has a WIP fix: main...fynnsu:vllm:fix_chunked_prefill_hs_connector

This also works on my end, thanks for the quick fix! Let's also put this in #37374.

> To test this I just temporarily hardcoded in a 2-second wait before reading the file in data_generation_offline.py

We could add a check in generate_hidden_states_* on speculators side to validate that vLLM has actually finished saving the file before reading it. To avoid TOCTOU issues we could use a sentinel file approach:

# vllm
safetensors.torch.save_file(tensors, pending.filename)
finished_sending.add(pending.req_id)
Path(f"{pending.filename}.done").touch()

# speculators
hs_filepath = extract_output(res, token_ids)
done_path = Path(f"{hs_filepath}.done")

while not done_path.exists():
    await asyncio.sleep(1)

done_path.unlink()
return hs_filepath
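
A self-contained version of the sentinel idea, runnable without vLLM or speculators, might look like the sketch below. The helper names `write_with_sentinel` and `wait_and_read` are hypothetical, plain bytes stand in for the safetensors payload, and the timeout is an addition not present in the snippet above.

```python
import asyncio
import tempfile
from pathlib import Path


def write_with_sentinel(path: Path, payload: bytes) -> None:
    # Writer side: finish the data file first, then create the marker.
    path.write_bytes(payload)
    Path(f"{path}.done").touch()


async def wait_and_read(path: Path, poll_s: float = 0.01,
                        timeout_s: float = 5.0) -> bytes:
    # Reader side: only trust the data file once the marker exists.
    done = Path(f"{path}.done")
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_s
    while not done.exists():
        if loop.time() > deadline:
            raise TimeoutError(f"no sentinel for {path} after {timeout_s}s")
        await asyncio.sleep(poll_s)
    data = path.read_bytes()
    done.unlink()  # consume the marker so a stale one is never reused
    return data


async def demo() -> bytes:
    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "hs_0.bin"
        loop = asyncio.get_running_loop()
        # Simulate the writer finishing after the reader has started waiting.
        loop.call_later(0.05, write_with_sentinel, target, b"hidden-states")
        return await wait_and_read(target)
```

Running `asyncio.run(demo())` returns the payload only after the marker appears. Bounding the wait avoids hanging forever if the writer crashes before touching the marker.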

mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
vllm-project#39949)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
rishitdholakia13 pushed a commit to rishitdholakia13/vllm that referenced this pull request May 19, 2026
vllm-project#39949)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
vllm-project#39949)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
h1t35h pushed a commit to h1t35h/vllm that referenced this pull request May 21, 2026
vllm-project#39949)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
learning-sketch added a commit to learning-sketch/vllm-ascend that referenced this pull request May 29, 2026
… padding with Ascend model runner

### What this PR does / why we need it?

`AscendExtractHiddenStatesProposer` inherits the upstream
`ExtractHiddenStatesProposer._determine_batch_execution_and_padding`
unchanged, which on Ascend causes two distinct failures when running
`extract_hidden_states` on a MoE target model with DP > 1 and
sequence parallelism enabled (e.g. MiniMax-M2 with
`VLLM_ASCEND_ENABLE_FLASHCOMM1=1`):

1. **gloo shape mismatch on the DP cpu_group**:

       what(): [enforce fail at .../gloo/transport/tcp/pair.cc:456]
       op.preamble.length <= op.nbytes. 8 vs 4.
       Received data size doesn't match expected size.
       Is there a distributed collective mismatch in your code?

   Upstream `coordinate_batch_across_dp` posts a `[4, dp_size]` int32
   tensor to the DP cpu_group, while Ascend's main runner uses
   `_sync_metadata_across_dp` with a `[2, dp_size]` tensor on the
   same cpu_group. The two shapes collide within one step.

2. **reduce_scatter shape-not-divisible assertion on the idle DP rank**:

       File ".../vllm_ascend/ops/linear_op.py", line 574, in matmul_and_reduce
           output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
       File ".../base_device_communicator.py", line 234, in reduce_scatter
           assert input_tensor.shape[0] % world_size == 0
       AssertionError

   The proposer's own `cudagraph_dispatcher` is initialized as
   PIECEWISE/NONE only (never FULL), so `dispatch(num_tokens=6)`
   returns 6 as-is (no SP padding). That 6 enters DP sync, the synced
   max stays 6, and the idle DP rank's main MoE forward then crashes
   in SP reduce_scatter because `6 % TP=4 != 0`.

   Eagle3/MTP do not reproduce this because Ascend's `AscendEagleProposer`
   uses `runner.cudagraph_dispatcher.dispatch(...)` which dispatches
   against the runner's FULL-mode capture sizes (always TP-aligned).

### Fix

Override `AscendExtractHiddenStatesProposer._determine_batch_execution_and_padding`:

1. SP-pad `num_tokens` via `runner._pad_for_sequence_parallelism`
   before dispatch, so the contribution to DP sync is always
   TP-aligned. Mirrors what the runner's main path does at
   `model_runner_v1.py:_determine_batch_execution_and_padding`.

2. Use `runner._sync_metadata_across_dp` (packed_tensor shape
   `[2, dp_size]`) for DP coordination instead of upstream
   `coordinate_batch_across_dp` (shape `[4, dp_size]`), so all DP
   collectives in a single step that hit the cpu_group use a
   consistent tensor shape.

3. Fail fast at the entry of the override with a clear `AssertionError`
   if the proposer was constructed without a `runner` reference,
   instead of letting the unguarded `runner._pad_for_sequence_parallelism`
   call raise a confusing `AttributeError`.

4. Document the `is_draft_model=True` semantics: it intentionally
   makes `should_skip_allreduce_across_dp_group` short-circuit (the
   cache-only "draft" here is not MoE), so the call degenerates to a
   local broadcast. The actual cross-DP all_reduce has already been
   done by the main runner earlier in the step; the SP padding above
   is what keeps the value TP-aligned regardless.
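
The arithmetic behind step 1 can be sketched in isolation. This assumes SP padding is a plain round-up to a multiple of the TP size and that DP sync takes the maximum across ranks; the real `runner._pad_for_sequence_parallelism` may apply extra conditions, and the standalone functions below are hypothetical.

```python
from typing import List


def pad_for_sequence_parallelism(num_tokens: int, tp_size: int) -> int:
    """Round num_tokens up to the next multiple of tp_size."""
    return -(-num_tokens // tp_size) * tp_size


def synced_num_tokens(per_rank_tokens: List[int], tp_size: int) -> int:
    """Pad each rank's count first, then take the max across DP ranks."""
    return max(pad_for_sequence_parallelism(n, tp_size) for n in per_rank_tokens)
```

With `tp_size=4`, a rank contributing 6 tokens is padded to 8 before the sync, so the value every rank sees is divisible by the TP size and the `reduce_scatter` shape check holds. Syncing the unpadded 6 would not be.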

### Does this PR introduce _any_ user-facing change?

No user-facing API change. Fixes a runtime crash for users running
`extract_hidden_states` speculative decoding with
`--data-parallel-size > 1` on MoE target models on Ascend NPU.
Single-DP runs and dense target models (e.g. Qwen3-8B) are unaffected.

### How was this patch tested?

Reproduced the crash and verified the fix on MiniMax-M2:

- 2x8 NPU, `--tensor-parallel-size 4 --data-parallel-size 2`
- `--enable-expert-parallel`
- `--compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}'`
- `--speculative-config '{"method": "extract_hidden_states", "num_speculative_tokens": 1, "draft_model_config": {"hf_config": {"eagle_aux_hidden_state_layer_ids": [2, 18, 34]}}}'`
- `--kv-transfer-config '{"kv_connector": "ExampleHiddenStatesConnector", "kv_role": "kv_producer", ...}'`
- Sending one `/v1/completions` request with `max_tokens=1`:
    - Before fix: idle DP rank crashes on first `execute_dummy_batch`
      with the `AssertionError` shown above.
    - After fix: request returns 200 OK, hidden_states `.safetensors`
      file is written with the expected
      `(prompt_len, len(layer_ids), hidden_size)` shape.

Also verified the existing dense Qwen3-8B + extract_hidden_states path
still works unchanged.

Unit tests added in
`tests/ut/spec_decode/test_extract_hidden_states_proposer.py`:

- `test_determine_batch_execution_and_padding_asserts_when_runner_is_none`:
  regression guard for the `AttributeError` that would otherwise be
  raised on the unguarded `self.runner._pad_for_sequence_parallelism`
  call at the entry of the override.
- `test_determine_batch_execution_and_padding_dp1_sp_pads_and_skips_sync`:
  with DP=1 the runner's `_pad_for_sequence_parallelism` is still
  consulted (so cache_only forward gets an SP-aligned input) but
  `_sync_metadata_across_dp` is not called.
- `test_determine_batch_execution_and_padding_dp2_uses_runner_sync`:
  with DP>1 the override calls `runner._sync_metadata_across_dp`
  with the SP-padded `num_tokens` and `is_draft_model=True`, and does
  NOT call the upstream `coordinate_batch_across_dp` (regression
  guard for the gloo `8 vs 4` shape mismatch).
- `test_determine_batch_execution_and_padding_dp2_keeps_tp_aligned_for_main_forward`:
  regression guard for the `reduce_scatter`
  `input.shape[0] % world_size == 0` assertion: the final
  `num_tokens_padded` returned to the caller is always TP-aligned.

Related upstream PR: vllm-project/vllm#39949 (added hybrid attention
model support to the extract_hidden_states speculative method).

Signed-off-by: learning-sketch <learning-sketch@users.noreply.github.com>
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
vllm-project#39949)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Labels

ci/build, kv-connector, ready, speculative-decoding, v1

7 participants