
[Perf][Bugfix] Update dflash aux layer indexing #40727

Open
benchislett wants to merge 1 commit into vllm-project:main from CentML:dflash-ar-bugfix

Conversation

@benchislett (Collaborator)

Purpose

A discrepancy in the auxiliary-layer indexing causes a slight gap in acceptance rates for DFlash vs. the reference implementation.

See: https://github.com/z-lab/dflash/blob/main/dflash/model.py#L44

This may also have implications for Speculators; I'm not sure whether they share this issue.
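
For illustration, here is a minimal sketch of the conversion this fix performs (it mirrors the +1 shift discussed in the review below; the layer ids in the last line are made-up example values):

def to_eagle_aux_layer_ids(target_layer_ids: list[int]) -> list[int]:
    # DFlash aux layer ids and vLLM's eagle_aux_hidden_state_layer_ids are
    # assumed here to be offset by exactly one position, hence the +1.
    return [i + 1 for i in target_layer_ids]

assert to_eagle_aux_layer_ids([15, 23]) == [16, 24]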

Test Plan

The existing dflash AR test passes, and the acceptance rate goes up.
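
For context (my paraphrase, not wording from the PR): acceptance_len is roughly the mean number of tokens committed per target-model verification step, so higher means more effective speculation. A minimal sketch of the metric:

def acceptance_len(accepted_per_step: list[int]) -> float:
    # Mean number of accepted tokens per verification step over a run.
    return sum(accepted_per_step) / len(accepted_per_step)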

Test Result

Baseline vLLM:

DFlash acceptance_len for mt-bench: [3.92, 4.03] (expected at least 3.82)
DFlash acceptance_len for humaneval: [6.29, 6.30] (expected at least 5.85)
DFlash acceptance_len for gsm8k (subset of 200 prompts): 6.18 (expected at least 5.59)

Offset by +2:

DFlash acceptance_len for mt-bench: 3.95 (expected at least 3.82)
DFlash acceptance_len for humaneval: 6.06 (expected at least 5.85)
DFlash acceptance_len for gsm8k (subset of 200 prompts): 6.08 (expected at least 5.59)

Offset by +1:

DFlash acceptance_len for mt-bench: [4.11, 4.16] (expected at least 3.82)
DFlash acceptance_len for humaneval: [6.44, 6.48] (expected at least 5.85)
DFlash acceptance_len for gsm8k (subset of 200 prompts): [6.37, 6.42] (expected at least 5.59)

Reference data from the paper (https://arxiv.org/pdf/2602.06036):
"mt-bench": 4.24,
"humaneval": 6.50,
"gsm8k": 6.54

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@benchislett benchislett requested a review from njhill as a code owner April 23, 2026 18:34

@claude (bot) left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify (bot) added the v1 and bug (Something isn't working) labels on Apr 23, 2026

@gemini-code-assist (bot) left a comment

Code Review

This pull request implements a +1 shift for DFlash auxiliary layer IDs in the GPU model runner to align with the expected semantics. Feedback suggests applying the shift directly in the configuration utility to ensure consistency across execution paths, and hardening the list comprehension against potential None values in the configuration dictionary.

Comment on lines +68 to 70
# TODO: does this need to be shifted by 1 like in gpu_model_runner?
aux_layer_ids = config_dict["aux_hidden_state_layer_ids"]
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids

Severity: high

The TODO should be addressed by applying the +1 shift here. Since gpu_model_runner.py prioritizes the eagle_aux_hidden_state_layer_ids field (line 4937), the fix in the model runner's fallback logic is bypassed for DFlash models configured through this function. Applying the shift here ensures that the correct layer indices are used in the primary execution path.

Suggested change:
-    # TODO: does this need to be shifted by 1 like in gpu_model_runner?
-    aux_layer_ids = config_dict["aux_hidden_state_layer_ids"]
-    pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
+    # Add 1 to convert DFlash's aux layer id semantics
+    aux_layer_ids = [i + 1 for i in config_dict["aux_hidden_state_layer_ids"]]
+    pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids

Comment on lines +4941 to +4942
# Add 1 to convert DFlash's aux layer id semantics
layer_ids = [i + 1 for i in dflash_config.get("target_layer_ids", [])]

Severity: high

Using dflash_config.get("target_layer_ids", []) can lead to a TypeError if the key exists in the dictionary but its value is explicitly set to None. It is safer to use dflash_config.get("target_layer_ids") or [] to ensure the list comprehension always receives an iterable.

Suggested change:
-    # Add 1 to convert DFlash's aux layer id semantics
-    layer_ids = [i + 1 for i in dflash_config.get("target_layer_ids", [])]
+    # Add 1 to convert DFlash's aux layer id semantics
+    layer_ids = [i + 1 for i in (dflash_config.get("target_layer_ids") or [])]
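
A quick illustration of the point above: dict.get falls back to its default only when the key is absent, not when the stored value is None.

cfg = {"target_layer_ids": None}
print(cfg.get("target_layer_ids", []))    # None -- the default is ignored
print(cfg.get("target_layer_ids") or [])  # [] -- always safe to iterate over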

repne pushed a commit to repne/vllm that referenced this pull request Apr 26, 2026
Threads layer_types and sliding_window through the DFlash Qwen3 drafter
so target models with SWA layers can be drafted correctly:

- Per-layer SWA in qwen3_dflash: builds Attention layers with
  sliding_window for sliding_attention entries in layer_types, exposes
  sliding_attention_layer_names for the proposer (sketched below).
- Speculators config: preserve layer_types, use_sliding_window,
  sliding_window, max_window_layers when extracting the HF config.
- DFlash proposer: force causal=True on the per-layer attention
  metadata for SWA layers so the windowed kernel runs correctly during
  parallel block drafting.

Built on top of vllm-project#40727 (target_layer_ids +1 shift). The shift in
gpu_model_runner.py here overlaps with vllm-project#40727 and can be dropped once
that lands.
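
As a rough sketch of the per-layer SWA selection described in the first bullet (hypothetical helper, not the actual vLLM code):

def window_for_layer(layer_idx: int,
                     layer_types: list[str],
                     sliding_window: int) -> int | None:
    # Only layers marked "sliding_attention" in layer_types get a window;
    # full-attention layers get None (unbounded attention).
    if layer_types[layer_idx] == "sliding_attention":
        return sliding_window
    return None

Each Attention layer would then be constructed with the window returned for its index, matching the commit's description of per-layer SWA in qwen3_dflash.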