Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions tests/ut/spec_decode/test_eagle_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def test_load_model_pp1(self, mock_pp_group, mock_get_model,
weight = torch.zeros(0)

mock_model = MagicMock()
mock_model.supports_multimodal = False
mock_model.lm_head = MagicMock()
mock_model.multimodal_cpu_fields = None
mock_model.merge_by_field_config = None
Expand All @@ -193,7 +194,7 @@ def test_load_model_pp1(self, mock_pp_group, mock_get_model,

self.proposer.load_model(mock_model)
mock_get_model.assert_called_once()
self.assertEqual(self.proposer.attn_layer_name, ["layer3"])
self.assertEqual(self.proposer.attn_layer_names, ["layer3"])
self.assertIs(self.proposer.model.model.embed_tokens,
mock_model.model.embed_tokens)

Expand Down Expand Up @@ -224,7 +225,7 @@ def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model,

self.assertIsNot(self.proposer.model.model.embed_tokens,
mock_model.model.embed_tokens)
self.assertEqual(self.proposer.attn_layer_name, ["layer2"])
self.assertEqual(self.proposer.attn_layer_names, ["layer2"])

@patch(
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
Expand Down Expand Up @@ -254,7 +255,7 @@ def test_load_model_multimodal(self, mock_supports_multi, mock_pp_group,
self.proposer.name = SpecDcodeType.EAGLE

self.proposer.load_model(mock_model)
mock_model.get_language_model.assert_called_once()
self.assertEqual(mock_model.get_language_model.call_count, 2)
self.assertIs(self.proposer.model.lm_head,
mock_model.get_language_model.return_value.lm_head)

Expand Down Expand Up @@ -435,4 +436,4 @@ def test_prepare_inputs(self):
torch.tensor([1, 2, 4]))):
return_attn, indices = self.proposer.prepare_inputs(
mock_attn, num_rejected)
self.assertEqual(indices.tolist(), [1, 2, 4])
self.assertEqual(indices.tolist(), [1, 2, 4])
82 changes: 66 additions & 16 deletions vllm_ascend/spec_decode/eagle_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,31 +197,81 @@ def load_model(self, model: nn.Module) -> None:
draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = list(sorted(draft_attn_layer_names))
self.attn_layer_names = self.attn_layer_name
self.attn_layer_names = list(sorted(draft_attn_layer_names))

if supports_multimodal(model):
# handle multimodality
if self.get_model_name(model) in [
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
]:
self.model.config.image_token_index = model.config.image_token_id
elif self.get_model_name(
model) == "PixtralForConditionalGeneration":
self.model.config.image_token_index = (
model.config.vision_config.image_token_id)
else:
self.model.config.image_token_index = (
model.config.image_token_index)
target_language_model = model.get_language_model()
else:
target_language_model = model

# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
# If pp>1, the weights of mtp and the main model's embedding are not on the same device.
# check whether the mtp model uses the main model's embedding and LM head
if hasattr(model, "model") and hasattr(model.model, "embed_tokens") and \
torch.equal(self.model.model.embed_tokens.weight,
model.model.embed_tokens.weight):
logger.info(
"The EAGLE head shares the same vocab embedding" \
" with the target model."
if hasattr(target_language_model.model, "embed_tokens"):
target_embed_tokens = target_language_model.model.embed_tokens
elif hasattr(target_language_model.model, "embedding"):
target_embed_tokens = target_language_model.model.embedding
else:
raise AttributeError(
"Target model does not have 'embed_tokens' or 'embedding' attribute"
)
self.model.model.embed_tokens = model.model.embed_tokens

share_embeddings = False
if hasattr(self.model, "has_own_embed_tokens"):
# EAGLE model
if not self.model.has_own_embed_tokens:
share_embeddings = True
logger.info(
"Detected EAGLE model without its own embed_tokens in the"
" checkpoint. Sharing target model embedding weights with the"
" draft model.")
elif (isinstance(target_embed_tokens.weight, torch.Tensor)
and isinstance(self.model.model.embed_tokens.weight,
torch.Tensor)
# TODO: Offload to CPU for comparison to avoid extra NPU memory
# usage in CI testing environments with limited NPU memory
and torch.equal(
target_embed_tokens.weight.cpu(),
self.model.model.embed_tokens.weight.cpu(),
)):
share_embeddings = True
logger.info(
"Detected EAGLE model with embed_tokens identical to the target"
" model. Sharing target model embedding weights with the draft"
" model.")
else:
logger.info(
"Detected EAGLE model with distinct embed_tokens weights. "
"Keeping separate embedding weights from the target model."
)
else:
# MTP model
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please follow the original logic.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, it's my mistake.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

share_embeddings = True
logger.info(
" The EAGLE head loaded its own vocab embedding" \
" weights instead of sharing them with the target model."
"Detected MTP model. "
"Sharing target model embedding weights with the draft model."
)

if share_embeddings:
if hasattr(self.model.model, "embed_tokens"):
del self.model.model.embed_tokens
self.model.model.embed_tokens = target_embed_tokens
else:
logger.info(
"Since PP > 1 or other reasons the model head loaded its own vocab embedding" \
" weights instead of sharing them with the target model."
)
"The draft model's vocab embedding will be loaded separately"
" from the target model.")

# share lm_head with the target model if needed
# some model definitions do not define lm_head explicitly
Expand Down
8 changes: 4 additions & 4 deletions vllm_ascend/spec_decode/mtp_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def dummy_run(self,
attn_metadata_mtp = builder.build_for_graph_capture(
common_attn_metadata, attn_state)
attn_metadata = {}
for layer_name in self.attn_layer_name:
for layer_name in self.attn_layer_names:
attn_metadata[layer_name] = attn_metadata_mtp
else:
attn_metadata = None
Expand Down Expand Up @@ -302,7 +302,7 @@ def _propose(
attn_metadata_mtp = builder.build(0, common_attn_metadata,
self.runner.get_model())
attn_metadata = {}
for layer_name in self.attn_layer_name:
for layer_name in self.attn_layer_names:
attn_metadata[layer_name] = attn_metadata_mtp

for step in range(self.num_speculative_tokens):
Expand Down Expand Up @@ -331,7 +331,7 @@ def _propose(
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
hidden_states)

for layer_name in self.attn_layer_name:
for layer_name in self.attn_layer_names:
decode_metadata = getattr(attn_metadata[layer_name],
"decode", None)
if self.use_async_scheduling and decode_metadata is not None:
Expand Down Expand Up @@ -402,7 +402,7 @@ def _propose(
if step == self.num_speculative_tokens - 1 or with_prefill:
break

attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
attn_metadata_i = attn_metadata[self.attn_layer_names[0]]

if step == 0:
positions = target_positions[last_token_indices]
Expand Down
Loading