33 changes: 33 additions & 0 deletions vllm/model_executor/models/llama_eagle3.py
@@ -30,6 +30,9 @@
    get_draft_quant_config,
    maybe_prefix,
    process_eagle_weight,
    get_rotation_path,
    get_rotation_matrix,
    compute_rotation_matrix3,
)

logger = init_logger(__name__)
@@ -266,6 +269,10 @@
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config

        self.target_model_config = vllm_config.speculative_config.target_model_config
        self.target_quant_config = vllm_config.quant_config

        # Ensure draft_vocab_size is set;
        # default to the base vocab size when absent
        if getattr(self.config, "draft_vocab_size", None) is None:
@@ -360,6 +367,19 @@
        return self.model.fc(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        # TODO: maybe extract a helper function
        rotation_path = get_rotation_path(
            self.target_model_config.model, self.target_quant_config
        )

        use_quarot = rotation_path is not None

        if use_quarot:
            Q = get_rotation_matrix(rotation_path)
            Q3 = compute_rotation_matrix3(Q)
            if isinstance(self.config.dtype, str):
                embed_dtype = getattr(torch, self.config.dtype)
            else:
                embed_dtype = self.config.dtype

        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
@@ -384,11 +404,24 @@
                continue
            elif "lm_head" not in name:
                name = "model." + name
            if "fc." in name and use_quarot:
                # anti-rotate fc
                dtype = loaded_weight.dtype
                loaded_weight = loaded_weight @ Q3.to(dtype)
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)

        # load the target model's embedding if the drafter checkpoint has none
        if use_quarot and not includes_embed_tokens:
            name = "model.embed_tokens.weight"
            loaded_weight = (
                get_embedding_tensor(self.target_model_config.model).to(embed_dtype)
                @ Q.T.to(embed_dtype)
            )

Contributor review comment (critical): The function get_embedding_tensor is called here but it is not defined or imported in this file, nor is it defined in this pull request. This will cause a NameError at runtime.
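For reference, a minimal sketch of what such a helper could look like, assuming the target checkpoint uses the standard Hugging Face safetensors layout. The helper name matches the call site above, but everything else here is an illustrative assumption, not part of this PR:

```python
import json
from pathlib import Path

import torch
from safetensors.torch import load_file


def get_embedding_tensor(target_model_path: str) -> torch.Tensor:
    """Load the target model's embed_tokens weight from its checkpoint.

    Hypothetical helper: assumes a standard HF safetensors layout.
    """
    root = Path(target_model_path)
    key = "model.embed_tokens.weight"
    index_file = root / "model.safetensors.index.json"
    if index_file.exists():
        # Sharded checkpoint: look up which shard holds the embedding.
        weight_map = json.loads(index_file.read_text())["weight_map"]
        shard = root / weight_map[key]
    else:
        # Single-file checkpoint.
        shard = root / "model.safetensors"
    return load_file(shard)[key]
```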

            model_weights[name] = loaded_weight
            includes_embed_tokens = True
            process_eagle_weight(self, name)

        if not includes_mask_hidden and self.use_parallel_drafting:
            raise ValueError(
                "mask_hidden not found in weights but "
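Why the fc weight is right-multiplied by Q3: EAGLE3's fc layer consumes the concatenation of hidden states from three target layers, and a QuaRot target emits those hidden states in a rotated basis. Assuming the rotation acts as h -> h @ Q per layer, right-multiplying the fc weight by the block-diagonal Q3 cancels it, since (h @ Q3) @ (W @ Q3).T = h @ Q3 @ Q3.T @ W.T = h @ W.T for orthogonal Q. A minimal numerical check of that identity (Q here is a random orthogonal matrix, not a real QuaRot rotation):

```python
import torch

H = 8  # toy hidden size
Q, _ = torch.linalg.qr(torch.randn(H, H, dtype=torch.float64))  # orthogonal Q
Q3 = torch.block_diag(Q, Q, Q)  # same construction as compute_rotation_matrix3

W = torch.randn(H, 3 * H, dtype=torch.float64)  # fc weight: [hidden, 3 * hidden]
h = torch.randn(2, 3 * H, dtype=torch.float64)  # concatenated hidden states

# The rotated target produces h @ Q3; the anti-rotated weight W @ Q3
# recovers the unrotated output.
assert torch.allclose((h @ Q3) @ (W @ Q3).T, h @ W.T)
```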
37 changes: 37 additions & 0 deletions vllm/model_executor/models/utils.py
@@ -875,3 +875,40 @@
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index


def get_rotation_path(target_model_path, quant_config):
    """
    Gets the path of the rotation matrix; returns None if the target
    model is not a QuaRot model.
    """
    try:
        quarot_cfg = quant_config.quant_description["optional"]["quarot"]
        rotation_relative_path = quarot_cfg["rotation_map"]["global_rotation"]
    except KeyError:
        return None

    # NOTE: `Path` must be imported from pathlib at the top of the file.
    return Path(target_model_path) / rotation_relative_path


def get_rotation_matrix(rotation_path):
    """
    Loads the anti-rotation matrix Q from the given path.
    """
    try:
        # NOTE: `load_file` must be imported from safetensors.torch
        # at the top of the file.
        safetensor_data = load_file(rotation_path)
        Q = safetensor_data["global_rotation"]
        return Q
    except Exception as e:
        logger.error(
            "Failed to load rotation weight from '%s'. "
            "Check the rotation file if you want to use a QuaRot model "
            "with EAGLE3.",
            rotation_path,
        )
        raise e


def compute_rotation_matrix3(Q):
    """
    Builds the block-diagonal anti-rotation matrix for the concatenated
    hidden states of 3 layers.
    """
    return torch.block_diag(Q, Q, Q)
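For illustration, the config shape get_rotation_path expects and how the three helpers compose. The quant_description contents and paths below are assumptions for the sketch, not values from a real checkpoint:

```python
from pathlib import Path
from types import SimpleNamespace

# Stand-in for a quant config whose quant_description carries the
# optional QuaRot section read by get_rotation_path.
quant_config = SimpleNamespace(
    quant_description={
        "optional": {
            "quarot": {
                "rotation_map": {
                    "global_rotation": "rotation/global_rotation.safetensors"
                }
            }
        }
    }
)

path = get_rotation_path("/models/llama-quarot", quant_config)
assert path == Path("/models/llama-quarot/rotation/global_rotation.safetensors")

# With a real checkpoint on disk one would then do:
#   Q = get_rotation_matrix(path)      # loads the "global_rotation" tensor
#   Q3 = compute_rotation_matrix3(Q)   # block_diag(Q, Q, Q) for the fc input
```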