1 change: 1 addition & 0 deletions examples/offline_inference/spec_decode.py
@@ -69,6 +69,7 @@ def parse_args():
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--request-id-prefix", type=str, default="")
return parser.parse_args()
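As a quick illustration of how the new flag might be consumed, the sketch below derives per-prompt request IDs from a prefix. This is a hypothetical usage example; the actual wiring inside spec_decode.py is not shown in this hunk.

```python
# Hypothetical sketch: build unique request IDs from a prefix such as
# the one parsed via --request-id-prefix.
prefix = "bench-"  # stands in for args.request_id_prefix
prompts = ["Hello, my name is", "The capital of France is"]
request_ids = [f"{prefix}{i}" for i in range(len(prompts))]
assert request_ids == ["bench-0", "bench-1"]
```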


46 changes: 32 additions & 14 deletions tests/v1/spec_decode/test_eagle.py
@@ -341,6 +341,9 @@ def create_deterministic_logits(token_ids):
dtype=torch.int32,
device=device)
sampling_metadata = mock.MagicMock()
# Simulate mixed greedy and non-greedy requests
# (temperature -1 marks the greedy request, 0.7 the non-greedy one).
sampling_metadata.all_greedy = False
sampling_metadata.temperature = torch.tensor([-1, 0.7], device=device)

if attn_backend == "FLASH_ATTN_VLLM_V1":
attn_metadata_builder_cls, _ = get_attention_backend(
@@ -365,34 +368,47 @@ def create_deterministic_logits(token_ids):
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder

result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
# Call the method under test
result, result_probs = proposer.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)

assert result.shape == (batch_size, num_speculative_tokens)
assert result_probs.shape == (batch_size, num_speculative_tokens,
vocab_size)

# Create expected tokens based on our token pattern
if num_speculative_tokens == 1:
# Example for num_speculative_tokens=1:
# [[42], [60]]
expected_tokens = torch.tensor(
[[base_token_ids[0]], [base_token_ids[1]]], device=device)
[[base_token_ids[0]], [base_token_ids[1]]],
dtype=torch.int64,
device=device)
expected_probs = torch.zeros((batch_size, 1, vocab_size),
device=device)
for i, token_id in enumerate(base_token_ids):
expected_probs[i, 0, token_id] = 1.0
else:
# Example for num_speculative_tokens=3:
# [[42, 43, 44], [60, 61, 62]]
expected_tokens = torch.zeros((batch_size, num_speculative_tokens),
dtype=torch.int64,
device=device)
expected_probs = torch.zeros(
(batch_size, num_speculative_tokens, vocab_size), device=device)
for i in range(batch_size):
for j in range(num_speculative_tokens):
expected_tokens[i, j] = base_token_ids[i] + j
expected_probs[i, j, base_token_ids[i] + j] = 1.0

# Verify all tokens match our expectations
assert torch.equal(result, expected_tokens)
torch.testing.assert_close(result_probs, expected_probs)
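The expected tensors above place probability 1.0 on exactly one token per position. A minimal, equivalent way to construct them with `torch.nn.functional.one_hot` is sketched below; the concrete sizes are placeholders, not the test's actual values.

```python
import torch
import torch.nn.functional as F

batch_size, num_speculative_tokens, vocab_size = 2, 3, 100  # placeholder sizes
base_token_ids = torch.tensor([42, 60])

# token_ids[i, j] = base_token_ids[i] + j, mirroring the consecutive-token pattern.
token_ids = base_token_ids.unsqueeze(1) + torch.arange(num_speculative_tokens)

expected_tokens = token_ids.to(torch.int64)
# Probability mass of 1.0 at each expected token, zeros elsewhere.
expected_probs = F.one_hot(token_ids, num_classes=vocab_size).float()
assert expected_probs.shape == (batch_size, num_speculative_tokens, vocab_size)
assert torch.equal(expected_probs.argmax(dim=-1), expected_tokens)
```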


@pytest.mark.parametrize(
@@ -517,13 +533,15 @@ def create_deterministic_logits(token_ids, k: int):
sampling_metadata = mock.MagicMock()

# Propose draft tokens.
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
result, draft_probs = proposer.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
assert result.shape == (batch_size, num_speculative_tokens)
assert draft_probs is None

# The tokens are expected to be consecutive integers starting
# from the base token IDs.
253 changes: 253 additions & 0 deletions tests/v1/spec_decode/test_scheduling.py
@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile

import pytest
import torch

from tests.v1.worker.test_gpu_model_runner import _schedule_new_request
from vllm.config import VllmConfig
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
from vllm.engine.arg_utils import EngineArgs
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.engine.core import get_kv_cache_config
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
# So we can share the DraftModelProposer between tests
return False


@pytest.fixture(scope="class")
def monkeyclass():
with pytest.MonkeyPatch.context() as mp:
yield mp


@pytest.fixture(scope="class")
def spec_decode_vllm_config_and_env_setup(monkeyclass: pytest.MonkeyPatch):
with monkeyclass.context() as m:
m.setenv("VLLM_USE_V1", "1")
vllm_config = EngineArgs(model=model_dir,
max_model_len=256,
cuda_graph_sizes=[1, 2, 4],
gpu_memory_utilization=0.8,
speculative_config={
"model": eagle_dir,
"method": "eagle",
"num_speculative_tokens": 2,
}).create_engine_config()
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(1, 1)
yield vllm_config
cleanup_dist_env_and_memory()


@pytest.fixture(scope="class")
def mock_spec_decode_model_runner(
spec_decode_vllm_config_and_env_setup: VllmConfig):
model_runner = GPUModelRunner(spec_decode_vllm_config_and_env_setup,
torch.device("cuda"))
model_runner.load_model()
kv_cache_spec = model_runner.get_kv_cache_spec()

kv_cache_config = get_kv_cache_config(
spec_decode_vllm_config_and_env_setup, kv_cache_spec, 1024**3) # 1GB
model_runner.initialize_kv_cache(kv_cache_config)
yield model_runner


class TestSpecDecodeScheduling:

def test_spec_decode_partial_scheduling(
self, mock_spec_decode_model_runner: GPUModelRunner):
"""Make sure we don't crash when the scheduler schedules only a subset
of the requests.

Four iterations:
1. Schedule both req1 (w/ 0 draft) and req2 (w/ 0 draft)
2. Schedule only req1 (w/ 1 draft)
3. Schedule both req1 (w/ 1 draft) and req2 (w/ 2 draft)
4. Terminate req1 and req2
"""
# Schedule both req1 and req2 on the first iteration
scheduler_output = _schedule_new_request("req1", "req2")
mock_spec_decode_model_runner.execute_model(scheduler_output)

# Only schedule req1 on the second iteration
cached_req_data = CachedRequestData(
req_ids=["req1"],
resumed_from_preemption=[False],
new_token_ids=[[3]],
new_block_ids=[([], )],
num_computed_tokens=[3],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={"req1": 2},
total_num_scheduled_tokens=2,
scheduled_spec_decode_tokens={"req1": [1001]},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[0],
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
mock_spec_decode_model_runner.execute_model(scheduler_output)

# Schedule both req1 and req2 on the third iteration
cached_req_data = CachedRequestData(
req_ids=["req1", "req2"],
resumed_from_preemption=[False, False],
new_token_ids=[[10], [11]],
new_block_ids=[([], ), ([], )],
num_computed_tokens=[4, 3],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={
"req1": 2,
"req2": 3
},
total_num_scheduled_tokens=5,
scheduled_spec_decode_tokens={
"req1": [1001],
"req2": [2001, 2002]
},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[0],
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
mock_spec_decode_model_runner.execute_model(scheduler_output)

# Terminate both req1 and req2
cached_req_data = CachedRequestData(
req_ids=[],
resumed_from_preemption=[],
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=[],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[0],
finished_req_ids={"req1", "req2"},
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
mock_spec_decode_model_runner.execute_model(scheduler_output)

def test_spec_decode_preemption_scheduling(
self, mock_spec_decode_model_runner: GPUModelRunner):
"""Make sure we don't crash when the scheduler preempts a request.

Four iterations:
1. Schedule req1 (w/ 0 draft) and req2 (w/ 0 draft)
2. Schedule req1 (w/ 1 draft) and preempt req2
3. Schedule req1 (w/ 1 draft) and resume req2 (w/ 2 draft)
4. Terminate req1 and req2
"""
# Schedule both req1 and req2 on the first iteration
scheduler_output = _schedule_new_request("req1", "req2")
mock_spec_decode_model_runner.execute_model(scheduler_output)

# Only schedule req1 on the second iteration
cached_req_data = CachedRequestData(
req_ids=["req1"],
resumed_from_preemption=[False],
new_token_ids=[[3]],
new_block_ids=[([], )],
num_computed_tokens=[3],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={"req1": 2},
total_num_scheduled_tokens=2,
scheduled_spec_decode_tokens={"req1": [1001]},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[0],
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
mock_spec_decode_model_runner.execute_model(scheduler_output)

# Schedule both req1 and req2 on the third iteration
cached_req_data = CachedRequestData(
req_ids=["req1", "req2"],
resumed_from_preemption=[False, True],
new_token_ids=[[10], [11]],
new_block_ids=[([], ), ([0], )],
num_computed_tokens=[4, 0],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={
"req1": 2,
"req2": 6
},
total_num_scheduled_tokens=8,
scheduled_spec_decode_tokens={
"req1": [1001],
"req2": [2001, 2002]
},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[0],
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
mock_spec_decode_model_runner.execute_model(scheduler_output)

# Terminate both req1 and req2
cached_req_data = CachedRequestData(
req_ids=[],
resumed_from_preemption=[],
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=[],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[0],
finished_req_ids={"req1", "req2"},
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
mock_spec_decode_model_runner.execute_model(scheduler_output)
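The token counts in the scheduler outputs above follow a simple accounting: each decode step schedules one new token plus the draft tokens carried by the request, and a request resumed from preemption additionally recomputes the tokens it lost (here what appears to be a 3-token prompt in these fixtures). The helper below is only an illustration of the numbers used in these tests, not scheduler code.

```python
def expected_num_scheduled_tokens(num_draft_tokens: int,
                                  num_recomputed_tokens: int = 0) -> int:
    """One new token + scheduled draft tokens (+ any tokens recomputed after preemption)."""
    return 1 + num_draft_tokens + num_recomputed_tokens


# Iteration 2: req1 carries one draft token -> 2 scheduled tokens.
assert expected_num_scheduled_tokens(1) == 2
# Iteration 3 (partial scheduling): req2 carries two draft tokens -> 3.
assert expected_num_scheduled_tokens(2) == 3
# Iteration 3 (preemption): req2 resumes, recomputing its 3 lost tokens,
# plus 1 new token and 2 drafts -> 6.
assert expected_num_scheduled_tokens(2, num_recomputed_tokens=3) == 6
```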
2 changes: 1 addition & 1 deletion tests/v1/worker/test_gpu_model_runner.py
@@ -139,7 +139,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[0],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
6 changes: 6 additions & 0 deletions vllm/config/__init__.py
@@ -1989,6 +1989,12 @@ class SpeculativeConfig:
speculative_token_tree: Optional[str] = None
"""Specifies the tree structure for speculative token generation.
"""
enable_draft_probs: bool = True
"""Whether to use draft probs for speculative decoding. Using draft probs
always increases the acceptance rate but increases sampling overhead.
For small models and/or low temperatures payloads, it may be beneficial to
disable this. Disabling falls back to greedy sampling for the draft tokens.
"""
# required configuration params passed from engine
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the target model."""