36 changes: 35 additions & 1 deletion examples/offline_inference/spec_decode.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json

from transformers import AutoTokenizer

@@ -54,9 +55,14 @@ def parse_args():
"--method",
type=str,
default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp"],
choices=["ngram", "eagle", "eagle3", "mtp", "ngram-eagle"],
)
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument(
"--num-speculative-tokens-per-method",
type=json.loads,
default='{"ngram": 2, "eagle": 2}',
)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
parser.add_argument("--prompt-lookup-min", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
@@ -119,6 +125,21 @@ def main(args):
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method == "ngram-eagle":
eagle_dir = args.eagle_dir
if eagle_dir is None:
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
args.num_spec_tokens = max(
args.num_speculative_tokens_per_method["ngram"],
args.num_speculative_tokens_per_method["eagle"],
)
speculative_config = {
"method": "ngram-eagle",
"model": eagle_dir,
"num_speculative_tokens_per_method": args.num_speculative_tokens_per_method,
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method == "mtp":
speculative_config = {
"method": "mtp",
@@ -156,6 +177,7 @@ def main(args):
print("-" * 50)
print(f"prompt: {output.prompt}")
print(f"generated text: {output.outputs[0].text}")
print(f"num of generated tokens: {len(output.outputs[0].token_ids)}")
print("-" * 50)

try:
@@ -185,6 +207,10 @@ def main(args):
assert isinstance(metric, Vector)
for pos in range(len(metric.values)):
acceptance_counts[pos] += metric.values[pos]
elif metric.name == "vllm:generation_tokens":
assert isinstance(metric, Counter)
print(f"num generation tokens: {metric.value}")
total_tokens_generated = metric.value

print("-" * 50)
print(f"total_num_output_tokens: {total_num_output_tokens}")
@@ -193,6 +219,14 @@
print(f"num_accepted_tokens: {num_accepted_tokens}")
acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
print(f"mean acceptance length: {acceptance_length:.2f}")
num_tokens_generated_without_sd = total_tokens_generated - (
num_drafts + num_accepted_tokens
)
seq_normalized_acceptance_length = (total_tokens_generated) / (
num_drafts + num_tokens_generated_without_sd
)
print(f"num_tokens_generated_without_sd: {num_tokens_generated_without_sd}")
print(f"seq normalized acceptance length: {seq_normalized_acceptance_length:.2f}")
print("-" * 50)

# print acceptance at each token position
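For orientation, the example can now be run with, e.g., --method ngram-eagle --num-speculative-tokens-per-method '{"ngram": 5, "eagle": 3}' (the JSON string is parsed by json.loads, and the effective num_spec_tokens becomes the larger of the two values, per the max() above). The two new summary metrics follow directly from the printed counters; below is a small worked sketch with made-up numbers (not from a real run), assuming each drafting step contributes one target-model token plus its accepted draft tokens:

total_tokens_generated = 120   # vllm:generation_tokens Counter value
num_drafts = 40                # number of drafting steps recorded
num_accepted_tokens = 50       # accepted draft tokens summed over positions

# tokens produced outside speculative decoding: everything not attributable to
# a drafting step (one verified token per draft) or an accepted draft token
num_tokens_generated_without_sd = total_tokens_generated - (num_drafts + num_accepted_tokens)

mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
seq_normalized_acceptance_length = total_tokens_generated / (num_drafts + num_tokens_generated_without_sd)

print(mean_acceptance_length)            # 2.25
print(num_tokens_generated_without_sd)   # 30
print(seq_normalized_acceptance_length)  # 120 / 70 ≈ 1.71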
36 changes: 28 additions & 8 deletions tests/v1/e2e/test_spec_decode.py
@@ -136,6 +136,8 @@ def test_ngram_correctness(
"head_dim not being a a multiple of 32")),
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("ngram-eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
@@ -150,8 +152,10 @@ def test_ngram_correctness(
"eagle618/eagle-deepseek-v3-random", 1), False),
],
ids=[
"qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3",
"llama4_eagle", "llama4_eagle_mm", "deepseek_eagle"
"qwen3_eagle3", "qwen2_5_vl_eagle3",
"llama3_eagle", "llama3_ngram_eagle",
"llama3_eagle3", "llama4_eagle",
"llama4_eagle_mm", "deepseek_eagle"
])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@@ -202,16 +206,32 @@ def test_eagle_correctness(
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
if method == "ngram-eagle":
# Use ngram-eagle specific config
speculative_config = {
"method": method,
"model": spec_model_name,
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens_per_method": {
"ngram": 3,
"eagle": 3
},
"max_model_len": 2048,
}
else:
speculative_config = {
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
},
}

spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config=speculative_config,
max_model_len=2048,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
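For readers who want to try the combined method outside pytest, here is a minimal offline sketch mirroring the speculative_config that the e2e test builds for method == "ngram-eagle"; the model names, lookup bounds, and per-method token counts are copied from the test parameters and are illustrative assumptions, not requirements.

from vllm import LLM, SamplingParams

# Sketch only: mirrors the ngram-eagle branch of the test's speculative_config.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "ngram-eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 3,
        "num_speculative_tokens_per_method": {"ngram": 3, "eagle": 3},
        "max_model_len": 2048,
    },
    max_model_len=2048,
)
outputs = llm.generate(["The capital of France is"], SamplingParams(temperature=0))
print(outputs[0].outputs[0].text)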
187 changes: 187 additions & 0 deletions tests/v1/spec_decode/test_ngram_eagle.py
@@ -0,0 +1,187 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest import mock

import pytest
import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"

NUM_SPECULATIVE_TOKENS_NGRAM = 5
NUM_SPECULATIVE_TOKENS_EAGLE = 3
PROMPT_LOOKUP_MIN = 2
PROMPT_LOOKUP_MAX = 5
DEVICE = current_platform.device_type


def _create_vllm_config(num_speculative_tokens_ngram: int,
num_speculative_tokens_eagle: int):
model_config = ModelConfig(model=model_dir,
runner="generate",
max_model_len=100)

# Choose model directory based on method
draft_model_dir = eagle_dir

speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
model=draft_model_dir,
method="ngram-eagle",
num_speculative_tokens_per_method={
"ngram": num_speculative_tokens_ngram,
"eagle": num_speculative_tokens_eagle
},
prompt_lookup_max=PROMPT_LOOKUP_MAX,
prompt_lookup_min=PROMPT_LOOKUP_MIN,
)

vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig())

return vllm_config


def test_proposer_config():

vllm_config = _create_vllm_config(NUM_SPECULATIVE_TOKENS_NGRAM,
NUM_SPECULATIVE_TOKENS_EAGLE)

# ngram proposer
ngram_proposer = NgramProposer(vllm_config=vllm_config)
assert ngram_proposer.k == NUM_SPECULATIVE_TOKENS_NGRAM
assert ngram_proposer.min_n == PROMPT_LOOKUP_MIN
assert ngram_proposer.max_n == PROMPT_LOOKUP_MAX

# eagle proposer
eagle_proposer = EagleProposer(vllm_config=vllm_config,
device=current_platform.device_type)
assert eagle_proposer.num_speculative_tokens == NUM_SPECULATIVE_TOKENS_EAGLE


@pytest.mark.parametrize(
"test_value",
[
{
"sampled_token_ids": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],
# ngram draft is empty
"propose_ngram_draft_token_ids": [[]]
},
{
"sampled_token_ids": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3]],
# ngram draft is not empty
"propose_ngram_draft_token_ids": [[4, 5, 6, 7, 8]]
}
])
@pytest.mark.parametrize("pp_size", [1, 2])
@mock.patch('vllm.v1.worker.gpu_model_runner.get_pp_group')
@mock.patch(
'vllm.v1.worker.gpu_model_runner.GPUModelRunner.propose_ngram_draft_token_ids'
)
@mock.patch('vllm.v1.worker.gpu_model_runner.EagleProposer.propose',
return_value=torch.tensor([[0, 1, 2]]))
@mock.patch('vllm.v1.worker.gpu_model_runner.EagleProposer.prepare_inputs',
return_value=(None, 0))
def test_propose_draft_token_ids(
mock_eagle_proposer_prepare_input,
mock_eagle_proposer_propose,
mock_propose_ngram_draft_token_ids,
mock_get_pp_group,
test_value,
pp_size,
):

vllm_config = _create_vllm_config(NUM_SPECULATIVE_TOKENS_NGRAM,
NUM_SPECULATIVE_TOKENS_EAGLE)

runner = GPUModelRunner(vllm_config, DEVICE)

# Setup mock for pp group to return the appropriate value for world size
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = pp_size
mock_get_pp_group.return_value = mock_pp_group

sampled_token_ids = test_value["sampled_token_ids"]
propose_ngram_draft_token_ids = test_value["propose_ngram_draft_token_ids"]

# with min matching ngram = 2, max matching ngram = 3
# we will find the prefix [1, 2, 3] in the history
# and speculate [4, 5, 6, 7, 8] for ngram
expected_ngram_proposals = [[4, 5, 6, 7, 8]]
expected_eagle_proposals = [[
i for i in range(NUM_SPECULATIVE_TOKENS_EAGLE)
]]
mock_propose_ngram_draft_token_ids \
.return_value = propose_ngram_draft_token_ids

# doesn't matter what this is for this test: START
scheduler_output = mock.MagicMock()
scheduler_output.total_num_scheduled_tokens = 1 + max(
vllm_config.speculative_config.num_speculative_tokens_per_method["ngram"],
vllm_config.speculative_config.num_speculative_tokens_per_method["eagle"])
hidden_states = torch.randn(len(sampled_token_ids[0]), 4096)
sample_hidden_states = None
aux_hidden_states = None
spec_decode_metadata = mock.MagicMock()
spec_decode_metadata.num_draft_tokens = [
max(NUM_SPECULATIVE_TOKENS_NGRAM, NUM_SPECULATIVE_TOKENS_EAGLE)
]
common_attn_metadata = None
sampling_metadata = None

# set runner attributes that would normally be set during init
runner.supports_mm_inputs = False

mock_positions = mock.MagicMock()
mock_positions_instance = mock_positions.return_value
mock_positions_instance.gpu = torch.tensor([0])
runner.positions = mock_positions_instance

mock_input_ids = mock.MagicMock()
mock_input_ids_instance = mock_input_ids.return_value
mock_input_ids_instance.gpu = torch.tensor([0])
runner.input_ids = mock_input_ids_instance

mock_req_ids = mock.MagicMock()
mock_req_ids.return_value = ["0"]
# doesn't matter what this is for this test: END

final_draft = runner.propose_draft_token_ids(
scheduler_output=scheduler_output,
sampled_token_ids=sampled_token_ids,
sampling_metadata=sampling_metadata,
hidden_states=hidden_states,
sample_hidden_states=sample_hidden_states,
aux_hidden_states=aux_hidden_states,
spec_decode_metadata=spec_decode_metadata,
common_attn_metadata=common_attn_metadata,
)

# case 1: ngram draft is empty. Eagle draft is used
if sampled_token_ids == [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]:
assert final_draft == expected_eagle_proposals, \
"ngram-eagle should have selected eagle draft"
# case 2: ngram draft is not empty. Ngram draft is used
elif sampled_token_ids == [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3]]:
assert final_draft == expected_ngram_proposals, \
"ngram-eagle should have selected ngram draft"
else:
raise ValueError("unexpected sampled_token_ids")
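The two parametrized cases above encode the selection rule the runner is expected to apply per request: use the ngram draft when the prompt-lookup match produced a non-empty draft, otherwise fall back to the EAGLE draft. A plain-Python illustration of that expectation (an assumption about the intended behavior, not the PR's gpu_model_runner implementation):

# Illustrative only: prefer a non-empty ngram draft, otherwise use the EAGLE draft.
def select_draft_per_request(ngram_draft: list[int], eagle_draft: list[int]) -> list[int]:
    return ngram_draft if ngram_draft else eagle_draft

# case 1: no ngram match -> eagle draft is used
assert select_draft_per_request([], [0, 1, 2]) == [0, 1, 2]
# case 2: ngram match -> ngram draft is used
assert select_draft_per_request([4, 5, 6, 7, 8], [0, 1, 2]) == [4, 5, 6, 7, 8]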