[Spec Decode] Support hybrid attention models in extract_hidden_states#39949
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the HiddenStateCacheSpec to support hidden-state extraction within the vLLM V1 engine. Key changes include updating the KV cache grouping heuristics to prevent singleton cache-only layers from collapsing group sizes, refactoring the ExampleHiddenStatesConnector to utilize attn_metadata.slot_mapping directly, and implementing dynamic HMA (Hybrid Memory Architecture) support checks for connectors. Feedback is provided regarding the max_memory_usage_bytes implementation in HiddenStateCacheSpec, which currently fails to account for context parallelism, potentially leading to memory over-estimation during initialization.
Hidden-state extraction breaks on hybrid-attention models (e.g. Qwen3.5) because kv_transfer_config force-disables HMA and unify_hybrid_kv_cache_specs cannot fold MambaSpec into a uniform type. Fix by gating HMA-disable on supports_hma(connector_cls), making ExampleHiddenStatesConnector a SupportsHMA subclass, and handling the cache-only layer's page alignment for hybrid models. Key changes: - HiddenStateCacheSpec: thin marker subclass of MLAAttentionSpec (inherits all dispatch behavior, no overrides). Defined in kv_cache_interface.py, registered in spec_manager_map. - get_kv_cache_groups: filter HiddenStateCacheSpec out before unify/grouping, add back as 1-layer group with page_size_padded aligned to the common page. General sub-functions untouched. - gpu_model_runner: as_strided reshape branch for padded specs (page_size_padded > real_page), proposer isinstance for kv_cache_gid. - Connector: read slot_mapping from attn_metadata (not scheduler block_ids), remove dead ReqMeta.slot_mapping field. - Proposer: kv_cache_gid for correct common_attn_metadata selection. - basic_cache/extract_from_kv_cache: block/offset indexing instead of flatten (works on non-contiguous strided tensors). Verified: Llama integration test + Qwen3.5-4B end-to-end on GPU. Signed-off-by: mgoin <mgoin64@gmail.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
12019e0 to
530539a
Compare
benchislett
left a comment
There was a problem hiding this comment.
Besides the other included comments, my main concern here is that we're adding a very manual specialization for this extractor. This creates deviation from existing standard practices (native padding e.g. in mamba) and requires us to maintain the codepath as the HMA changes. This seems quite fragile, especially since bugs in the hidden extractor would not manifest during routine debugging / correctness testing.
I'm not convinced that it isn't feasible to simply apply padding to the hiddenstatecachespec (or to the other blocks, whichever is smaller, either case should be fine) and handle that natively: either in the kv connector itself or in the client code.
Specifically, I have problems with the code in gpu_model_runner.py which manually strides the view to avoid the padding. This seems to break some conventions about how we handle the KV connectors (which, AFAICT, seem to keep the padding?). Also, the special handling of the hidden state group in vllm/v1/core/kv_cache_utils.py seems fragile and prone to issues.
|
this is good @mgoin... seems to enable nemotron too? |
… into nemotron-eagle3-support Brings in @mgoin's PR vllm-project#39949 (vllm-project/vllm) which lets the extract_hidden_states speculative method work on hybrid attention backbones (Mamba-2 / GatedDeltaNet + attention). Why pull this in early: * Validated end-to-end against nvidia/NVIDIA-Nemotron-3-Nano-4B-BF16 on a DGX Spark (GB10): vLLM serve OK, /health 200, speculators online DFlash training runs ~150 real loss steps before tuning-related NaN. Without vllm-project#39949, vLLM cannot even start the verifier (NotImplementedError on page-size unification). * Composes cleanly with this branch's NemotronH SupportsEagle3 hooks (5925cca..f92ca38) -- no overlap in files, single conflict was in gpu_model_runner.py only. Conflict resolution: * vllm/v1/worker/gpu_model_runner.py -- kept the upstream/main version of the as_strided branch in _reshape_kv_cache_tensors. Upstream refactored this path to gate on `kv_cache_spec.page_size_padded is not None` and to allocate strides via `torch.empty(kv_cache_shape). stride()`, which is functionally equivalent to the PR's `> real_page_size_bytes` check + manual `inner.stride()` on this spec but more general (already in upstream main, post-dates the PR). All other bits of the PR (HiddenStateCacheSpec marker class, SupportsHMA on ExampleHiddenStatesConnector, kv_cache_utils group filter+add-back, attn_metadata.slot_mapping, proposer kv_cache_gid) apply unchanged. Verified post-merge: python ast.parse on all 8 changed files -> OK HiddenStateCacheSpec class present at vllm/v1/kv_cache_interface.py:376 ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA) confirmed If/when vllm-project#39949 lands upstream, the merge will resolve to a no-op for this branch; nothing here forks the PR. Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Johnny Nunez <johnnynuca14@gmail.com> Made-with: Cursor Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
|
I'm pretty sure the manual striding was merged into gpu model runner for DSV4. Assuming we can piggy-back on that, I think we can get this in without much change. Can you merge main? |
This PR's vllm/config/vllm.py change replaced the unconditional HMA force-disable for kv_transfer_config with a per-connector supports_hma check. Two test patterns started hitting the factory's "HMA enabled but connector doesn't support it" raise: 1. Tests that built VllmConfig with default Nixl (supports_hma) and then mutated kv_transfer_config to a non-HMA connector. By that point __post_init__ had already set disable_hybrid_kv_cache_manager based on Nixl, so the factory raised when the actual connector was built. 2. MultiConnector subclasses SupportsHMA at the class level, so __post_init__ left HMA enabled, but MultiConnector.__init__ asserts every sub-connector supports HMA at runtime. Fixes: - vllm/config/vllm.py: when the connector is MultiConnector, recurse into kv_connector_extra_config["connectors"] and AND-fold their supports_hma so the auto-disable matches the runtime assertion. Unblocks test_multi_connector_mixed_hma_disables_hybrid_kv_cache, which explicitly verifies this behavior. - tests/v1/kv_connector/unit/utils.py: add kv_connector_module_path parameter to create_vllm_config so tests can build the config with the actual external connector class instead of mutating after construction. - 4 unit tests: switch from post-construction kv_transfer_config mutation to construction-time create_vllm_config arguments. __post_init__ then resolves the real connector class and auto-disable kicks in correctly. No disable_hybrid_kv_cache_manager overrides in test code. - .buildkite/test_areas/misc.yaml: add "V1 Extract Hidden States Integration" job covering tests/v1/kv_connector/extract_hidden_states_integration (predictable-Llama and Qwen3.5-0.8B hybrid smoke tests). Signed-off-by: mgoin <mgoin64@gmail.com> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
5e2b47d to
7a76dd7
Compare
The autouse fixture registers PredictableLlamaForCausalLM in the parent pytest process via ModelRegistry.register_model. CI sets VLLM_WORKER_MULTIPROC_METHOD=spawn, which starts the engine worker as a fresh Python process that doesn't inherit the registration, so the worker fails with "Model architectures ['PredictableLlamaForCausalLM'] are not supported". Override to fork inside the test so the worker inherits the parent's ModelRegistry state. The Qwen3.5 hybrid smoke test in the same file doesn't need this override (it uses a real registered architecture). Signed-off-by: mgoin <mgoin64@gmail.com> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The HMA auto-disable resolves the kv connector class via KVConnectorFactory.get_connector_class(self.kv_transfer_config), which raises if kv_connector is None. The kv_offloading_backend path leaves kv_connector unset until _post_init_kv_transfer_config() populates it (e.g. "OffloadingConnector", "LMCacheConnectorV1"), so callers using --kv-offloading-backend without --kv-connector hit a ValidationError. Move _post_init_kv_transfer_config() to run before the HMA block. This also fixes a latent issue where the earlier cudagraph-mode connector inspection would see a stale user-provided kv_connector before kv_offloading overwrote it. Signed-off-by: mgoin <mgoin64@gmail.com> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
This pull request has merge conflicts that must be resolved before it can be |
Resolved two conflicts: - vllm/config/vllm.py: kept early _post_init_kv_transfer_config(), added _verify_kv_transfer_compat() from main - test_backwards_compatibility.py: accepted main's deletion (compat removed) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: mgoin <mgoin64@gmail.com>
|
I'm trying to replicate the setup for training a model like https://huggingface.co/z-lab/Qwen3.5-27B-DFlash and the memory usage is quite unreasonable on AMD, it tells me I need over 200 GB VRAM for a 27B model... Docker: Error message: Things I tried but didn't help:
|
NickLucche
left a comment
There was a problem hiding this comment.
Changes here went a bit too deep with the MultiConnector bit @mgoin :(
| all_support_hma = supports_hma(connector_cls) | ||
| # MultiConnector subclasses SupportsHMA; only effectively | ||
| # supports HMA when every sub-connector does. | ||
| if all_support_hma and connector_cls.__name__ == "MultiConnector": | ||
| sub_ktcs = self.kv_transfer_config.kv_connector_extra_config.get( | ||
| "connectors", [] | ||
| ) | ||
| all_support_hma = all( | ||
| supports_hma( | ||
| KVConnectorFactory.get_connector_class( | ||
| KVTransferConfig(**sub) | ||
| ) | ||
| ) | ||
| for sub in sub_ktcs | ||
| ) | ||
| if not all_support_hma: | ||
| need_disable_hybrid_kv_cache_manager = True | ||
| logger.warning( | ||
| "Turning off hybrid kv cache manager because " | ||
| "connector %s does not subclass `SupportsHMA`. " | ||
| "This will reduce performance on models with " | ||
| "sliding window or Mamba attention. See " | ||
| "kv_connector/v1/base.py for details.", | ||
| connector_cls.__name__, | ||
| ) |
|
I worked around the OOM by setting a lower value of This only happens roughly once every 500 samples, and the sample numbers that are failing are not deterministic. This leads me to think that it's a race condition of some sort. |
|
Ok this happens with |
|
cc @fynnsu |
|
Hmm strange that the hidden state shape is less than the num tokens. I will try to repro this. |
|
If I set |
|
@DarkLight1337 Okay, I was able to reproduce the issue with Claude seems to think its an issue with chunked prefill, and I'm trying to develop a fix. Unfortunately this model doesn't seem to work if you just |
|
This branch has a WIP fix: main...fynnsu:vllm:fix_chunked_prefill_hs_connector That seems to solve the problem but it also makes the save happen after the request finishes so we need to add handling in speculators for this. (To test this I just temporarily hardcoded in a 2-second wait before reading the file in data_generation_offline.py). #37374 also enforces a single save (by requiring no-chunked-prefill) and adds async lock file logic so that the downstream client know when it's safe to read the file. I think we can build a proper fix off of that pr. This should also get around the no-chunked-prefill requirement because instead we just save all the blocks when the request finishes which allows chunked prefill to run as normal. |
This also works on my end, thanks for the quick fix! Let's also put this in #37374.
We could add a check in |
vllm-project#39949) Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
vllm-project#39949) Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
vllm-project#39949) Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
vllm-project#39949) Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… padding with Ascend model runner
### What this PR does / why we need it?
`AscendExtractHiddenStatesProposer` inherits the upstream
`ExtractHiddenStatesProposer._determine_batch_execution_and_padding`
unchanged, which on Ascend causes two distinct failures when running
`extract_hidden_states` on a MoE target model with DP > 1 and
sequence parallelism enabled (e.g. MiniMax-M2 with
`VLLM_ASCEND_ENABLE_FLASHCOMM1=1`):
1. **gloo shape mismatch on the DP cpu_group**:
what(): [enforce fail at .../gloo/transport/tcp/pair.cc:456]
op.preamble.length <= op.nbytes. 8 vs 4.
Received data size doesn't match expected size.
Is there a distributed collective mismatch in your code?
Upstream `coordinate_batch_across_dp` posts a `[4, dp_size]` int32
tensor to the DP cpu_group, while Ascend's main runner uses
`_sync_metadata_across_dp` with a `[2, dp_size]` tensor on the
same cpu_group. The two shapes collide within one step.
2. **reduce_scatter shape-not-divisible assertion on the idle DP rank**:
File ".../vllm_ascend/ops/linear_op.py", line 574, in matmul_and_reduce
output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
File ".../base_device_communicator.py", line 234, in reduce_scatter
assert input_tensor.shape[0] % world_size == 0
AssertionError
The proposer's own `cudagraph_dispatcher` is initialized as
PIECEWISE/NONE only (never FULL), so `dispatch(num_tokens=6)`
returns 6 as-is (no SP padding). That 6 enters DP sync, the synced
max stays 6, and the idle DP rank's main MoE forward then crashes
in SP reduce_scatter because `6 % TP=4 != 0`.
Eagle3/MTP do not reproduce this because Ascend's `AscendEagleProposer`
uses `runner.cudagraph_dispatcher.dispatch(...)` which dispatches
against the runner's FULL-mode capture sizes (always TP-aligned).
### Fix
Override `AscendExtractHiddenStatesProposer._determine_batch_execution_and_padding`:
1. SP-pad `num_tokens` via `runner._pad_for_sequence_parallelism`
before dispatch, so the contribution to DP sync is always
TP-aligned. Mirrors what the runner's main path does at
`model_runner_v1.py:_determine_batch_execution_and_padding`.
2. Use `runner._sync_metadata_across_dp` (packed_tensor shape
`[2, dp_size]`) for DP coordination instead of upstream
`coordinate_batch_across_dp` (shape `[4, dp_size]`), so all DP
collectives in a single step that hit the cpu_group use a
consistent tensor shape.
3. Fail fast at the entry of the override with a clear `AssertionError`
if the proposer was constructed without a `runner` reference,
instead of letting the unguarded `runner._pad_for_sequence_parallelism`
call raise a confusing `AttributeError`.
4. Document the `is_draft_model=True` semantics: it intentionally
makes `should_skip_allreduce_across_dp_group` short-circuit (the
cache-only "draft" here is not MoE), so the call degenerates to a
local broadcast. The actual cross-DP all_reduce has already been
done by the main runner earlier in the step; the SP padding above
is what keeps the value TP-aligned regardless.
### Does this PR introduce _any_ user-facing change?
No user-facing API change. Fixes a runtime crash for users running
`extract_hidden_states` speculative decoding with
`--data-parallel-size > 1` on MoE target models on Ascend NPU.
Single-DP runs and dense target models (e.g. Qwen3-8B) are unaffected.
### How was this patch tested?
Reproduced the crash and verified the fix on MiniMax-M2:
- 2x8 NPU, `--tensor-parallel-size 4 --data-parallel-size 2`
- `--enable-expert-parallel`
- `--compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}'`
- `--speculative-config '{"method": "extract_hidden_states", "num_speculative_tokens": 1, "draft_model_config": {"hf_config": {"eagle_aux_hidden_state_layer_ids": [2, 18, 34]}}}'`
- `--kv-transfer-config '{"kv_connector": "ExampleHiddenStatesConnector", "kv_role": "kv_producer", ...}'`
- Sending one `/v1/completions` request with `max_tokens=1`:
- Before fix: idle DP rank crashes on first `execute_dummy_batch`
with the `AssertionError` shown above.
- After fix: request returns 200 OK, hidden_states `.safetensors`
file is written with the expected
`(prompt_len, len(layer_ids), hidden_size)` shape.
Also verified the existing dense Qwen3-8B + extract_hidden_states path
still works unchanged.
Unit tests added in
`tests/ut/spec_decode/test_extract_hidden_states_proposer.py`:
- `test_determine_batch_execution_and_padding_asserts_when_runner_is_none`:
regression guard for the `AttributeError` that would otherwise be
raised on the unguarded `self.runner._pad_for_sequence_parallelism`
call at the entry of the override.
- `test_determine_batch_execution_and_padding_dp1_sp_pads_and_skips_sync`:
with DP=1 the runner's `_pad_for_sequence_parallelism` is still
consulted (so cache_only forward gets an SP-aligned input) but
`_sync_metadata_across_dp` is not called.
- `test_determine_batch_execution_and_padding_dp2_uses_runner_sync`:
with DP>1 the override calls `runner._sync_metadata_across_dp`
with the SP-padded `num_tokens` and `is_draft_model=True`, and does
NOT call the upstream `coordinate_batch_across_dp` (regression
guard for the gloo `8 vs 4` shape mismatch).
- `test_determine_batch_execution_and_padding_dp2_keeps_tp_aligned_for_main_forward`:
regression guard for the `reduce_scatter`
`input.shape[0] % world_size == 0` assertion 闁?the final
`num_tokens_padded` returned to the caller is always TP-aligned.
Related upstream PR: vllm-project/vllm#39949 (introduced
extract_hidden_states speculative method).
Signed-off-by: learning-sketch <learning-sketch@users.noreply.github.com>
vllm-project#39949) Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Summary
Hidden-state extraction now works on hybrid-attention models (e.g. Qwen3.5). The kv-transfer config no longer force-disables HMA, it stays on for connectors that declare SupportsHMA.
The KV-cache grouping handles the cache-only hidden-state layer alongside a hybrid attention/Mamba layout by filtering it out before unification and adding it back as its own page-aligned 1-layer group, with a strided reshape in the model runner to span the padded page.
ExampleHiddenStatesConnector is updated accordingly, and a new CI job runs the integration tests.
Test plan
tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py— Llama end-to-end (GPU)extract_hidden_states— hybrid model end-to-end (GPU), hidden states shape[N, 3, 2560]with non-zero valuespre-commit run ruff-check / ruff-format / mypy-3.10— all passing🤖 AI-assisted (Claude)