Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/transformers_utils/configs/speculators/algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def update_dflash(config_dict: dict, pre_trained_config: dict) -> None:
if config_dict.get("target_hidden_size") is not None:
pre_trained_config["target_hidden_size"] = config_dict["target_hidden_size"]

# TODO: does this need to be shifted by 1 like in gpu_model_runner?
aux_layer_ids = config_dict["aux_hidden_state_layer_ids"]
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
Comment on lines +68 to 70
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The TODO should be addressed by applying the +1 shift here. Since gpu_model_runner.py prioritizes the eagle_aux_hidden_state_layer_ids field (line 4937), the fix in the model runner's fallback logic is bypassed for DFlash models configured through this function. Applying the shift here ensures that the correct layer indices are used in the primary execution path.

Suggested change
# TODO: does this need to be shifted by 1 like in gpu_model_runner?
aux_layer_ids = config_dict["aux_hidden_state_layer_ids"]
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
# Add 1 to convert DFlash's aux layer id semantics
aux_layer_ids = [i + 1 for i in config_dict["aux_hidden_state_layer_ids"]]
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids


Expand Down
3 changes: 2 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4938,7 +4938,8 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None:
if not layer_ids:
dflash_config = getattr(hf_config, "dflash_config", None)
if dflash_config and isinstance(dflash_config, dict):
layer_ids = dflash_config.get("target_layer_ids")
# Add 1 to convert DFlash's aux layer id semantics
layer_ids = [i + 1 for i in dflash_config.get("target_layer_ids", [])]
Comment on lines +4941 to +4942
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using dflash_config.get("target_layer_ids", []) can lead to a TypeError if the key exists in the dictionary but its value is explicitly set to None. It is safer to use dflash_config.get("target_layer_ids") or [] to ensure the list comprehension always receives an iterable.

Suggested change
# Add 1 to convert DFlash's aux layer id semantics
layer_ids = [i + 1 for i in dflash_config.get("target_layer_ids", [])]
# Add 1 to convert DFlash's aux layer id semantics
layer_ids = [i + 1 for i in (dflash_config.get("target_layer_ids") or [])]


if layer_ids and isinstance(layer_ids, (list, tuple)):
return tuple(layer_ids)
Expand Down
Loading