Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
047e904
support qwen3-vl eagle3
jesse996 Dec 10, 2025
2a698c0
add test
jesse996 Dec 16, 2025
44f90da
lint
jesse996 Dec 16, 2025
b9e2b6f
fix tests
jesse996 Dec 16, 2025
80c41ee
lint
jesse996 Dec 16, 2025
73297eb
Merge branch 'main' into eagle-qwen3vl
jesse996 Dec 18, 2025
9bece76
Merge branch 'main' into eagle-qwen3vl
jesse996 Dec 18, 2025
16e76f7
Merge branch 'main' into eagle-qwen3vl
jesse996 Dec 19, 2025
1dbfe1a
Merge branch 'main' into eagle-qwen3vl
jesse996 Dec 21, 2025
af060c2
Merge branch 'main' into eagle-qwen3vl
jesse996 Dec 22, 2025
fea09b0
Merge branch 'main' into eagle-qwen3vl
jesse996 Dec 29, 2025
dd7f9a6
Merge branch 'main' into eagle-qwen3vl
jesse996 Dec 29, 2025
1e66771
Merge branch 'main' into eagle-qwen3vl
jesse996 Dec 31, 2025
9a58127
Merge branch 'main' into eagle-qwen3vl
jesse996 Dec 31, 2025
fdf8afb
Merge branch 'main' into eagle-qwen3vl
jesse996 Jan 5, 2026
74dc940
Merge branch 'main' into eagle-qwen3vl
jesse996 Jan 6, 2026
1af997c
Merge branch 'main' into eagle-qwen3vl
jesse996 Jan 8, 2026
30aef44
Merge branch 'main' into eagle-qwen3vl
jesse996 Jan 14, 2026
1874acb
Merge branch 'main' into eagle-qwen3vl
jesse996 Jan 14, 2026
a760e14
fix modelrunner positions
jesse996 Jan 14, 2026
8d34834
Merge branch 'main' into eagle-qwen3vl
jesse996 Jan 14, 2026
5c30916
update test
jesse996 Jan 15, 2026
21631a7
gh add new model
jesse996 Jan 15, 2026
103d4a3
Merge branch 'main' into eagle-qwen3vl
jesse996 Jan 15, 2026
880dafe
Merge branch 'main' into eagle-qwen3vl
jesse996 Jan 16, 2026
c8b54d7
Merge branch 'main' into eagle-qwen3vl
jesse996 Jan 16, 2026
928d38d
update
jesse996 Jan 16, 2026
8c72466
Merge branch 'main' into eagle-qwen3vl
jesse996 Jan 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/misc/model_list.json
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
"moonshotai/Kimi-K2-Thinking",
"moonshotai/Kimi-Linear-48B-A3B-Instruct",
"neuralmagic/Qwen2.5-3B-quantized.w8a8",
"MNN/Qwen3-VL-8B-Instruct-Eagle3",
"nv-community/audio-flamingo-3",
"nv-community/audio-flamingo-3-hf",
"nvidia/audio-flamingo-3-hf",
Expand Down Expand Up @@ -234,4 +235,3 @@
"xlangai/OpenCUA-7B"
]
}

50 changes: 50 additions & 0 deletions tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ def eagle3_model_name():
return "vllm-ascend/EAGLE3-LLaMA3.1-Instruct-8B"


@pytest.fixture
def vl_model_name():
    """Target multimodal (vision-language) model used as the reference LLM."""
    return "Qwen/Qwen3-VL-8B-Instruct"

def vl_eagle3_model_name():
    """Return the EAGLE3 draft-model checkpoint paired with Qwen3-VL-8B.

    NOTE: deliberately a plain helper rather than a pytest fixture — tests
    call it directly instead of receiving it via dependency injection.
    """
    draft_checkpoint = "MNN/Qwen3-VL-8B-Instruct-Eagle3"
    return draft_checkpoint


def test_ngram_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
Expand Down Expand Up @@ -129,6 +137,48 @@ def test_ngram_correctness(
assert matches > int(0.66 * len(ref_outputs))


def test_qwen3_vl_eagle_correctness(
    test_prompts: list[list[dict[str, Any]]],
    sampling_config: SamplingParams,
    vl_model_name: str,
):
    '''
    The outputs of an original (target) LLM and the same LLM running with
    EAGLE3 speculative decoding should be largely identical: speculative
    decoding only accelerates sampling, it must not change the distribution.
    This test compares greedy-style chat outputs of the two configurations
    and requires a high exact-match rate.
    '''
    # Reference run: target model alone, no speculative decoding.
    with VllmRunner(
            vl_model_name,
            max_model_len=1024,
            cudagraph_capture_sizes=[1, 2, 4, 8],
    ) as ref_llm:
        ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)

    # Speculative run: same target model plus the EAGLE3 draft head,
    # proposing 2 speculative tokens per step.
    spec_model_name = vl_eagle3_model_name()
    with VllmRunner(
            vl_model_name,
            speculative_config={
                "method": "eagle3",
                "model": spec_model_name,
                "num_speculative_tokens": 2,
            },
            max_model_len=1024,
            cudagraph_capture_sizes=[1, 2, 4, 8],
    ) as runner:
        spec_outputs = runner.model.chat(test_prompts, sampling_config)
    matches = 0
    misses = 0
    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
        if ref_output.outputs[0].text == spec_output.outputs[0].text:
            matches += 1
        else:
            misses += 1
            # Print mismatches so CI logs show exactly which prompts diverged.
            print(f"ref_output: {ref_output.outputs[0].text}")
            print(f"spec_output: {spec_output.outputs[0].text}")

    # Heuristic: expect strictly more than 66% of the prompts to match exactly.
    # Upon failure, inspect the printed outputs to check for inaccuracy.
    assert matches > int(0.66 * len(ref_outputs))

def test_suffix_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
Expand Down
4 changes: 3 additions & 1 deletion tests/ut/spec_decode/test_eagle_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def test_initialization_eagle_graph(self):
self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
self.vllm_config.model_config.enforce_eager = False
self.vllm_config.model_config.uses_mrope = False
self.vllm_config.speculative_config.enforce_eager = False
self.vllm_config.scheduler_config.async_scheduling = False
init_ascend_config(self.vllm_config)
Expand Down Expand Up @@ -156,6 +157,7 @@ def test_load_model_pp1(self, mock_pp_group, mock_get_model,
}]

mock_model = MagicMock()
mock_model.supports_multimodal = False
mock_model.model.embed_tokens = MagicMock()
mock_model.lm_head = MagicMock()
mock_model.multimodal_cpu_fields = None
Expand Down Expand Up @@ -226,7 +228,7 @@ def test_load_model_multimodal(self, mock_supports_multi, mock_pp_group,
self.proposer.name = SpecDcodeType.EAGLE

self.proposer.load_model(mock_model)
mock_model.get_language_model.assert_called_once()
self.assertEqual(mock_model.get_language_model.call_count, 2)
self.assertIs(self.proposer.model.lm_head,
mock_model.get_language_model.return_value.lm_head)

Expand Down
110 changes: 87 additions & 23 deletions vllm_ascend/spec_decode/eagle_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(self,
self.dcp_size = self.runner.dcp_size
self.pcp_rank = self.runner.pcp_rank
self.dcp_rank = self.runner.dcp_rank

self.use_aclgraph = self.runner._use_aclgraph()

self.full_indices = range(
Expand Down Expand Up @@ -149,8 +149,34 @@ def load_model(self, model: nn.Module) -> None:
assert len(draft_attn_layer_names) == 1
self.attn_layer_names = list(draft_attn_layer_names)

if supports_multimodal(model):
# handle multimodality
if self.get_model_name(model) in [
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
]:
self.model.config.image_token_index = model.config.image_token_id
elif self.get_model_name(
model) == "PixtralForConditionalGeneration":
self.model.config.image_token_index = (
model.config.vision_config.image_token_id)
else:
self.model.config.image_token_index = (
model.config.image_token_index)
target_language_model = model.get_language_model()
else:
target_language_model = model

# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
if hasattr(target_language_model.model, "embed_tokens"):
target_embed_tokens = target_language_model.model.embed_tokens
elif hasattr(target_language_model.model, "embedding"):
target_embed_tokens = target_language_model.model.embedding
else:
raise AttributeError(
"Target model does not have 'embed_tokens' or 'embedding' attribute"
)
if self.method == "mtp":
if self.vllm_config.model_config.is_deepseek_mla and \
torch.equal(self.model.model.embed_tokens.weight,
Expand All @@ -161,7 +187,7 @@ def load_model(self, model: nn.Module) -> None:
"The MTP head shares the same vocab embedding" \
" with the target model."
)
self.model.model.embed_tokens = model.model.embed_tokens
self.model.model.embed_tokens = target_embed_tokens
else:
logger.info(
" The MTP head loaded its own vocab embedding" \
Expand All @@ -172,13 +198,12 @@ def load_model(self, model: nn.Module) -> None:
"The EAGLE head shares the same vocab embedding" \
" with the target model."
)
self.model.model.embed_tokens = model.model.embed_tokens
self.model.model.embed_tokens = target_embed_tokens
else:
logger.info(
"Since PP > 1 or other reasons the model head loaded its own vocab embedding" \
" weights instead of sharing them with the target model."
)

# share lm_head with the target model if needed
# some model definition do not define lm_head explicitly
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
Expand Down Expand Up @@ -221,7 +246,7 @@ def dummy_run(self,
dummy_compute_logits=lambda hidden_states: None,
is_profile=False):
# update global cos, sin
update_cos_sin(self.positions[:num_tokens])
update_cos_sin(self._get_positions(num_tokens))

attn_metadata = None
if not self.use_cuda_graph:
Expand Down Expand Up @@ -265,7 +290,7 @@ def dummy_run(self,
attn_metadata[layer_name] = attn_metadata_eagle

model_input_ids = self.input_ids[:num_tokens]
model_positions = self.positions[:num_tokens]
model_positions = self._get_positions(num_tokens)
model_previous_hidden_states = self.hidden_states[:num_tokens]
for i in range(self.num_speculative_tokens):
if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
Expand Down Expand Up @@ -340,7 +365,6 @@ def _propose(
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids

if self.use_cuda_graph and \
num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
Expand All @@ -356,15 +380,28 @@ def _propose(
batch_descriptor = None

# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states

if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
inputs_embeds = self.model.embed_input_ids(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed)
self.inputs_embeds[:num_tokens] = inputs_embeds
inputs_embeds = self.inputs_embeds[:num_input_tokens]
input_ids = self.input_ids[:num_input_tokens]
else:
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]

# FIXME(woosuk): The below two ops cause synchronization. Optimize.
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = builder.build(0, common_attn_metadata,
self.runner.get_model())
# update global cos, sin
update_cos_sin(self.positions[:num_input_tokens])
update_cos_sin(self._get_positions(num_input_tokens))
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
Expand All @@ -380,7 +417,7 @@ def _propose(
# The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
# `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
model_input_ids = self.input_ids[:num_input_tokens]
model_positions = self.positions[:num_input_tokens]
model_positions = self._get_positions(num_input_tokens)
model_hidden_states = self.hidden_states[:num_input_tokens]

model_hidden_states, model_positions = self.maybe_pad_and_reduce(
Expand All @@ -390,6 +427,7 @@ def _propose(
input_ids=model_input_ids,
positions=model_positions,
hidden_states=model_hidden_states,
inputs_embeds = inputs_embeds
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
Expand Down Expand Up @@ -420,8 +458,10 @@ def _propose(
dtype=draft_token_ids.dtype,
device=self.device)
draft_token_ids_tensor[0] = draft_token_ids

positions = target_positions[last_token_indices]
if self.uses_mrope:
positions = target_positions[:, last_token_indices]
else:
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
last_token_indices = self.arange[:batch_size]

Expand Down Expand Up @@ -460,11 +500,18 @@ def _propose(
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.vllm_config.model_config.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
if self.uses_mrope:
exceeds_max_model_len = positions[
0] >= self.vllm_config.model_config.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(
exceeds_max_model_len.unsqueeze(0),
torch.zeros_like(positions), positions)
else:
exceeds_max_model_len = positions >= self.vllm_config.model_config.max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)

# TODO: Increment the sequence lengths.

Expand All @@ -485,12 +532,19 @@ def _propose(
block_size = attn_metadata_builder.kv_cache_spec.block_size

# Compute the slot mapping.
block_numbers = (clamped_positions // block_size)
if self.uses_mrope:
block_numbers = clamped_positions[0] // block_size
else:
block_numbers = (clamped_positions // block_size)
block_ids = attn_metadata.block_tables.gather(
dim=1, index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
slot_mapping_tmp = (block_ids * block_size +
clamped_positions % block_size)
if self.uses_mrope:
slot_mapping_tmp = (block_ids * block_size +
clamped_positions[0] % block_size)
else:
slot_mapping_tmp = (block_ids * block_size +
clamped_positions % block_size)

# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
Expand All @@ -504,14 +558,23 @@ def _propose(
PADDING_SLOT_ID)
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self._set_positions(batch_size, clamped_positions)
self.hidden_states[:batch_size] = hidden_states
if self.supports_mm_inputs:
self.inputs_embeds[:batch_size] = self.model.embed_input_ids(
input_ids)

input_ids = self.input_ids[:input_batch_size]
inputs_embeds = self.inputs_embeds[:input_batch_size]
else:
input_ids = self.input_ids[:input_batch_size]
inputs_embeds = None
attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()

attn_metadata.attn_mask = attn_mask

# update global cos, sin
update_cos_sin(self.positions[:input_batch_size])
update_cos_sin(self._get_positions(input_batch_size))

# Run the model.
with set_ascend_forward_context(
Expand All @@ -526,7 +589,7 @@ def _propose(
# The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
# `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
model_input_ids = self.input_ids[:input_batch_size]
model_positions = self.positions[:input_batch_size]
model_positions = self._get_positions(input_batch_size)
model_hidden_states = self.hidden_states[:input_batch_size]

model_hidden_states, model_positions = self.maybe_pad_and_reduce(
Expand All @@ -536,6 +599,7 @@ def _propose(
input_ids=model_input_ids,
positions=model_positions,
hidden_states=model_hidden_states,
inputs_embeds = inputs_embeds
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
Expand Down
8 changes: 4 additions & 4 deletions vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,14 +1354,14 @@ def propose_draft_token_ids(
query_start_loc_pcp_full[1:num_reqs + 1] - 1
target_token_ids = input_ids_pcp_full[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_positions = self._get_positions(num_scheduled_tokens)
target_hidden_states = hidden_states
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids.gpu[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_positions = self._get_positions(num_scheduled_tokens)
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([
h[:num_scheduled_tokens]
Expand Down Expand Up @@ -1402,7 +1402,7 @@ def propose_draft_token_ids(
target_hidden_states = hidden_states
else:
target_token_ids = self.input_ids.gpu[token_indices]
target_positions = positions[token_indices]
target_positions = self._get_positions(token_indices)
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states],
Expand Down Expand Up @@ -3006,7 +3006,7 @@ def capture_model(self) -> None:
def _prepare_multimodal_fields(self):
"""
Ensures specific multimodal tensors are on CPU.
This is necessary for fields like 'grid_thw' which are converted to numpy
This is necessary for fields like 'grid_thw' which are converted to numpy
inside the model's forward pass.
"""
if not self.multimodal_cpu_fields:
Expand Down