diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 496ff85f7562..b21fc9c904cc 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -995,6 +995,7 @@ def create_deterministic_logits(token_ids): proposer.draft_attn_groups = [mock_attn_group] result = proposer.propose( + num_speculative_tokens=num_speculative_tokens, target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -1150,6 +1151,7 @@ def create_deterministic_logits(token_ids, k: int): # Propose draft tokens. result = proposer.propose( + num_speculative_tokens=num_speculative_tokens, target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, diff --git a/tests/v1/spec_decode/test_extract_hidden_states.py b/tests/v1/spec_decode/test_extract_hidden_states.py index 27b2a53c1849..1cac0ee63188 100644 --- a/tests/v1/spec_decode/test_extract_hidden_states.py +++ b/tests/v1/spec_decode/test_extract_hidden_states.py @@ -258,6 +258,7 @@ def test_propose(): # Call propose draft_tokens = proposer.propose( + num_speculative_tokens=1, sampled_token_ids=sampled_token_ids, target_hidden_states=target_hidden_states, common_attn_metadata=common_attn_metadata, @@ -324,6 +325,7 @@ def test_propose_different_layer_counts(num_hidden_layers): ).unsqueeze(-1) draft_tokens = proposer.propose( + num_speculative_tokens=1, sampled_token_ids=sampled_token_ids, target_hidden_states=target_hidden_states, common_attn_metadata=common_attn_metadata, diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index 0a48b0e7b98c..9fa228fb05d5 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -204,6 +204,7 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): # Run propose result = proposer.propose( + num_speculative_tokens=num_speculative_tokens, target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 7d2a07ddcec7..459edddd1c2e 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -81,6 +81,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match. token_ids_cpu = np.array([[1, 2, 3, 4, 5]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -90,6 +91,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match for 4-gram. token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -99,6 +101,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match for 4-gram but match for 3-gram. token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -109,6 +112,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # In this case, the proposer should return the 4-gram match. 
token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -118,6 +122,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Match for 2-gram and 3-gram, but not 4-gram. token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -127,6 +132,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Multiple 3-gram matched, but always pick the first one. token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -136,6 +142,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # check empty input token_ids_cpu = np.array([[]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -147,6 +154,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # second request has 3 tokens and no match. Padded with -1 for max len 5 token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0], [1]], num_tokens_no_spec=np.array([5, 3]), token_ids_cpu=token_ids_cpu, @@ -166,6 +174,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: num_tokens_no_spec = np.array([5, 3, 5], dtype=np.int32) sampled_token_ids = [[2], [], [8]] # Empty list for request 1 simulates prefill result = proposer.propose( + num_speculative_tokens=2, sampled_token_ids=sampled_token_ids, num_tokens_no_spec=num_tokens_no_spec, token_ids_cpu=token_ids_cpu, @@ -195,6 +204,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: input_2[:3] = [4, 5, 6] token_ids_cpu = np.array([input_1, input_2]) result = ngram_proposer.propose( + num_speculative_tokens=2, sampled_token_ids=[[0], [1]], num_tokens_no_spec=np.array([len(input_1), 3]), token_ids_cpu=token_ids_cpu, diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index f1fda9afd318..114c61dc6f03 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -64,6 +64,41 @@ RejectionSampleMethod = Literal["strict", "probabilistic", "synthetic"] +@config +class DynamicSpeculativeConfig: + """Profiling statistics used to choose the optimal number of draft tokens + for a given batch size, letting the scheduler dynamically adjust the + number of drafts to the current batch size.""" + + is_online: bool = False + """Whether the statistics are updated online during inference.""" + + batch_stats: dict[int, dict[int, float]] | None = None + """ + Batch statistics for different batch sizes and numbers of drafts. 
+ The structure is as follows: + { + batch_size: { + num_drafts: itl (i.e., inter token latency in ms) + } + } + + e.g., + { + 1: { 0: 6.87, 3: 9.41, 5: 10.8}, + 4: { 0: 7.3, 3: 9.95, 5: 11.59}, + } + + where bs 1 at K=3 has itl 9.41ms. K=0 means no speculative decoding. + """ + + max_num_speculative_tokens: int | None = None + """Maximum number of speculative tokens supported in the statistics.""" + + acceptance_rate_per_pos: list[float] | None = None + """Acceptance rate per position on an offline dataset.""" + + @config class SpeculativeConfig: """Configuration for speculative decoding.""" @@ -150,6 +185,12 @@ class SpeculativeConfig: target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore """The parallel configuration for the target model.""" + # dynamic speculative decoding control + dynamic_config_path: str | None = None + """Path to config file for dynamic speculative decoding, if provided.""" + dynamic_config: SkipValidation[DynamicSpeculativeConfig] | None = None + """Loaded dynamic speculative config, populated from dynamic_config_path.""" + # params generated in the post-init stage draft_model_config: SkipValidation[ModelConfig] = None # type: ignore """The configuration of the draft model initialized internal.""" @@ -628,6 +669,16 @@ def __post_init__(self): self.target_parallel_config, self.draft_tensor_parallel_size ) ) + + # load DynamicSpeculativeConfig: maybe use get_hf_file_to_dict() later + if self.dynamic_config_path is not None: + import json + + with open(self.dynamic_config_path) as f: + data = json.load(f) + + self.dynamic_config = DynamicSpeculativeConfig(**data) + return self def _validate_suffix_decoding(self): diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index 0b3958dbcf5a..381b758870ff 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -18,6 +18,10 @@ def __init__(self, *args, **kwargs) -> None: def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: super()._update_after_schedule(scheduler_output) spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens + # Use the latest num of scheduled draft tokens in next step as placeholder. + self._spec_token_placeholders = [ + -1 + ] * scheduler_output.num_spec_tokens_to_schedule for req_id in scheduler_output.num_scheduled_tokens: request = self.requests[req_id] if request.is_prefill_chunk: diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index bdb97decadfe..08e26476ac34 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -238,6 +238,10 @@ class SchedulerOutput: # preventing stale NaN/data from corrupting attention or SSM computation. new_block_ids_to_zero: list[int] | None = None + # Dynamic speculative decoding: optimal K chosen by scheduler. + # Number of spec tokens to schedule for the next step. 
+ num_spec_tokens_to_schedule: int = 0 + @classmethod def make_empty(cls) -> "SchedulerOutput": return cls( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c28a5d18ae77..18f51691a160 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -57,6 +57,7 @@ from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus, StreamingUpdate +from vllm.v1.spec_decode.dynamic.manager import DynamicSpeculativeDecodingManager from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.utils import record_function_or_nullcontext @@ -213,8 +214,15 @@ def __init__( speculative_config = vllm_config.speculative_config self.use_eagle = False self.num_spec_tokens = self.num_lookahead_tokens = 0 + self.dynamic_sd_manager: DynamicSpeculativeDecodingManager | None = None if speculative_config: self.num_spec_tokens = speculative_config.num_speculative_tokens + if speculative_config.dynamic_config: + self.dynamic_sd_manager = DynamicSpeculativeDecodingManager( + speculative_config.dynamic_config, + self.scheduler_config.max_num_seqs, + self.num_spec_tokens, + ) if speculative_config.use_eagle(): self.use_eagle = True self.num_lookahead_tokens = self.num_spec_tokens @@ -904,6 +912,13 @@ def schedule(self) -> SchedulerOutput: else None ) + # Dynamic speculative decoding: compute optimal K + num_spec_tokens_to_schedule = self.num_spec_tokens + if self.dynamic_sd_manager is not None and len(num_scheduled_tokens) > 0: + num_spec_tokens_to_schedule = self.dynamic_sd_manager.step( + len(num_scheduled_tokens) + ) + scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -920,6 +935,7 @@ def schedule(self) -> SchedulerOutput: finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), new_block_ids_to_zero=new_block_ids_to_zero, + num_spec_tokens_to_schedule=num_spec_tokens_to_schedule, ) # NOTE(Kuntai): this function is designed for multiple purposes: @@ -1975,12 +1991,15 @@ def make_spec_decoding_stats( num_invalid_spec_tokens: dict[str, int] | None, request_id: str, ) -> SpecDecodingStats | None: + if num_invalid_spec_tokens: + num_draft_tokens -= num_invalid_spec_tokens.get(request_id, 0) + if self.dynamic_sd_manager is not None and num_draft_tokens: + self.dynamic_sd_manager.observe_draft(num_draft_tokens, num_accepted_tokens) + if not self.log_stats or not num_draft_tokens: return None if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) - if num_invalid_spec_tokens: - num_draft_tokens -= num_invalid_spec_tokens.get(request_id, 0) spec_decoding_stats.observe_draft( num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens ) diff --git a/vllm/v1/spec_decode/dynamic/__init__.py b/vllm/v1/spec_decode/dynamic/__init__.py new file mode 100644 index 000000000000..208f01a7cb5e --- /dev/null +++ b/vllm/v1/spec_decode/dynamic/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm/v1/spec_decode/dynamic/generate_config.py b/vllm/v1/spec_decode/dynamic/generate_config.py new file mode 100644 index 000000000000..c20aed5ed5cf --- /dev/null +++ b/vllm/v1/spec_decode/dynamic/generate_config.py 
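Before the new `generate_config.py` below, a minimal sketch (not part of the patch, numbers invented) of a `dynamic_config_path` file in the `batch_stats` layout documented above: `{batch_size: {num_drafts K: inter-token latency in ms}}`, with K=0 meaning no speculation. The loader in `SpeculativeConfig.__post_init__` above passes the `json.load` result straight to `DynamicSpeculativeConfig`, but JSON round-trips integer dict keys as strings, while `batch_stats` is typed `dict[int, dict[int, float]]` and the manager indexes it with ints, so a normalization step like the one sketched would presumably be needed:

```python
import json

config = {
    "is_online": False,
    # {batch_size: {num_drafts K: median ITL in ms}}; K=0 = no speculation.
    # The manager asserts that every batch size has K=0 and K=1 entries.
    "batch_stats": {
        1: {0: 6.87, 1: 8.0, 3: 9.41, 5: 10.8},
        4: {0: 7.3, 1: 8.4, 3: 9.95, 5: 11.59},
    },
    "max_num_speculative_tokens": 5,
    "acceptance_rate_per_pos": [0.8, 0.7, 0.6, 0.5, 0.4],
}

with open("dynamic_speculative_config.json", "w") as f:
    json.dump(config, f, indent=4)

# Caveat: json.dump/json.load turn the integer keys above into strings.
# A hypothetical normalization pass before DynamicSpeculativeConfig(**data):
with open("dynamic_speculative_config.json") as f:
    data = json.load(f)
data["batch_stats"] = {
    int(bs): {int(k): float(itl) for k, itl in stats.items()}
    for bs, stats in data["batch_stats"].items()
}
```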
@@ -0,0 +1,461 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses +import json +import pprint +import time +from pathlib import Path +from typing import Any + +from transformers import AutoTokenizer + +from vllm import LLM +from vllm.benchmarks.datasets import add_dataset_parser, get_samples +from vllm.benchmarks.sweep.param_sweep import ParameterSweep +from vllm.benchmarks.sweep.serve import SweepServeArgs, run_main +from vllm.config.speculative import DynamicSpeculativeConfig +from vllm.sampling_params import SamplingParams +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.v1.metrics.reader import Counter, Vector + + +def build_serve_params( + method, + draft_dir, + tp, + num_speculative_tokens_list, + prompt_lookup_max, + prompt_lookup_min, +): + """Build serve parameter sweep for vanilla + speculative decode configs. + + Each entry becomes a separate server configuration in the sweep. + The sweep framework starts/stops the server for each serve config. + """ + records: list[dict[str, object]] = [] + + # Vanilla config (no speculative decoding) + records.append({"_benchmark_name": "vanilla"}) + + # Speculative decoding configs with varying num_speculative_tokens + if method == "ngram": + for k in num_speculative_tokens_list: + records.append( + { + "_benchmark_name": f"ngram-k-{k}", + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": k, + "prompt_lookup_max": prompt_lookup_max, + "prompt_lookup_min": prompt_lookup_min, + }, + } + ) + elif method in ("eagle", "eagle3"): + for k in num_speculative_tokens_list: + records.append( + { + "_benchmark_name": f"{method}-k-{k}", + "speculative_config": { + "method": method, + "model": draft_dir, + "num_speculative_tokens": k, + "draft_tensor_parallel_size": tp, + }, + } + ) + elif method == "mtp": + for k in num_speculative_tokens_list: + records.append( + { + "_benchmark_name": f"mtp-k-{k}", + "speculative_config": { + "method": "mtp", + "num_speculative_tokens": k, + }, + } + ) + + return ParameterSweep.from_records(records) + + +def build_bench_params(batch_size_list, num_batches): + """Build benchmark parameter sweep for different concurrencies. + + Each entry varies max_concurrency (batch size) and num_prompts. + """ + records = [] + for bs in batch_size_list: + num_prompts = num_batches * bs + records.append( + { + "_benchmark_name": f"bs-{bs}", + "max_concurrency": bs, + "num_prompts": num_prompts, + } + ) + return ParameterSweep.from_records(records) + + +def parse_itl_from_dataframe(result_df): + """Parse ITL from sweep result DataFrame into batch_stats format. + + Returns: + batch_stats: dict of {batch_size: {num_drafts: median_itl_ms}} + where num_drafts=0 corresponds to vanilla (no speculation). + """ + batch_stats: dict[int, dict[int, float]] = {} + for _, row in result_df.iterrows(): + bs = int(row["max_concurrency"]) + + # Determine k (num speculative tokens) from speculative_config. + # For vanilla rows, speculative_config is NaN (missing from serve + # params); for spec decode rows, it's a dict or JSON string. 
+ spec_config = row.get("speculative_config") + if isinstance(spec_config, dict): + k = int(spec_config["num_speculative_tokens"]) + elif isinstance(spec_config, str): + k = int(json.loads(spec_config)["num_speculative_tokens"]) + else: + k = 0 # vanilla (NaN or None) + + if bs not in batch_stats: + batch_stats[bs] = {} + batch_stats[bs][k] = row["median_itl_ms"] + + return batch_stats + + +def run_profiling_sweep(args): + """Run profiling benchmarks using vllm bench sweep serve.""" + # Base serve command (static params shared across all serve configs) + serve_cmd = [ + "vllm", + "serve", + args.model_dir, + "--gpu-memory-utilization", + "0.95", + "--max-num-seqs", + str(args.max_vllm_batch_size), + "--tensor-parallel-size", + str(args.tp), + "--enable-chunked-prefill", + "--no-enable-prefix-caching", + ] + + # Base bench command (static params shared across all bench configs) + bench_cmd = [ + "vllm", + "bench", + "serve", + "--model", + args.model_dir, + "--backend", + "openai-chat", + "--endpoint", + "/v1/chat/completions", + "--dataset-name", + args.dataset_name, + "--dataset-path", + args.dataset_path, + f"--temperature={args.temp}", + f"--top-p={args.top_p}", + f"--top-k={args.top_k}", + ] + + # Build parameter sweeps + serve_params = build_serve_params( + method=args.method, + draft_dir=args.draft_dir, + tp=args.tp, + num_speculative_tokens_list=args.num_speculative_tokens_list, + prompt_lookup_max=args.prompt_lookup_max, + prompt_lookup_min=args.prompt_lookup_min, + ) + + bench_params = build_bench_params( + batch_size_list=args.batch_size_list, + num_batches=args.num_batches, + ) + + # Run the sweep + sweep_args = SweepServeArgs( + serve_cmd=serve_cmd, + bench_cmd=bench_cmd, + after_bench_cmd=[], + show_stdout=True, + serve_params=serve_params, + bench_params=bench_params, + output_dir=Path(args.result_dir), + num_runs=1, + dry_run=False, + resume=False, + link_vars=[], + server_ready_timeout=600, + experiment_name=f"{args.method}-{args.num_spec_tokens}", + ) + + result_df = run_main(sweep_args) + return result_df + + +def get_acceptance_rate_per_pos(args): + """Get acceptance rate per position.""" + + tokenizer = AutoTokenizer.from_pretrained(args.model_dir) + prompts = get_samples(args, tokenizer) + llm_prompts: list[Any] + if args.enable_multimodal_chat: + llm_prompts = [p.prompt for p in prompts] + else: + # add_special_tokens is False to avoid adding bos twice + # when using chat templates + llm_prompts = [ + { + "prompt_token_ids": tokenizer.encode( + prompt.prompt, add_special_tokens=False + ), + "multi_modal_data": prompt.multi_modal_data, + } + for prompt in prompts + ] + if args.method == "eagle" or args.method == "eagle3": + eagle_dir = args.eagle_dir + if args.method == "eagle" and eagle_dir is None: + eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + + elif args.method == "eagle3" and eagle_dir is None: + eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + speculative_config = { + "method": args.method, + "model": eagle_dir, + "num_speculative_tokens": args.num_spec_tokens, + "disable_padded_drafter_batch": args.disable_padded_drafter_batch, + "parallel_drafting": args.parallel_drafting, + } + elif args.method == "ngram": + speculative_config = { + "method": "ngram", + "num_speculative_tokens": args.num_spec_tokens, + "prompt_lookup_max": args.prompt_lookup_max, + "prompt_lookup_min": args.prompt_lookup_min, + } + elif args.method == "draft_model": + assert args.draft_model is not None and args.draft_model != "" + speculative_config = { + "method": args.method, + 
"model": args.draft_model, + "num_speculative_tokens": args.num_spec_tokens, + "enforce_eager": args.enforce_eager, + "max_model_len": args.max_model_len, + "parallel_drafting": args.parallel_drafting, + } + elif args.method == "mtp": + speculative_config = { + "method": "mtp", + "num_speculative_tokens": args.num_spec_tokens, + } + else: + raise ValueError(f"unknown method: {args.method}") + + llm = LLM( + model=args.model_dir, + trust_remote_code=True, + tensor_parallel_size=args.tp, + enable_chunked_prefill=args.enable_chunked_prefill, + enforce_eager=args.enforce_eager, + gpu_memory_utilization=args.gpu_memory_utilization, + speculative_config=speculative_config, + disable_log_stats=False, + max_model_len=args.max_model_len, + limit_mm_per_prompt={"image": 5}, + disable_chunked_mm_input=True, + max_num_seqs=args.max_vllm_batch_size, + allowed_local_media_path=args.allowed_local_media_path, + ) + + sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) + if args.backend == "openai-chat": + _ = llm.chat(llm_prompts, sampling_params=sampling_params) + else: + _ = llm.generate( + llm_prompts, + sampling_params=sampling_params, + ) + + metrics = llm.get_metrics() + + num_drafts = 0 + num_draft_tokens = 0 + num_accepted_tokens = 0 + acceptance_counts = [0] * args.num_spec_tokens + for metric in metrics: + if metric.name == "vllm:spec_decode_num_drafts": + assert isinstance(metric, Counter) + num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_draft_tokens": + assert isinstance(metric, Counter) + num_draft_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens": + assert isinstance(metric, Counter) + num_accepted_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + assert isinstance(metric, Vector) + for pos in range(len(metric.values)): + acceptance_counts[pos] += metric.values[pos] + + acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 + print(f"mean acceptance length: {acceptance_length:.2f}") + print("-" * 50) + + # print acceptance at each token position + acceptance_rate_per_pos = [] + for i in range(len(acceptance_counts)): + acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 + print(f"acceptance at token {i}: {acceptance_rate:.2f}") + acceptance_rate_per_pos.append(acceptance_rate) + + return acceptance_length, acceptance_rate_per_pos + + +def main(): + parser = FlexibleArgumentParser() + add_dataset_parser(parser) + + parser.add_argument("--model-dir", type=str, default=None) + parser.add_argument("--draft-dir", type=str, default=None) + parser.add_argument( + "--method", + type=str, + default="eagle", + choices=["ngram", "eagle", "eagle3", "mtp"], + ) + parser.add_argument( + "--num-speculative-tokens-list", nargs="*", type=int, default=[1, 3, 5] + ) + parser.add_argument( + "--batch-size-list", nargs="*", type=int, default=[1, 4, 16, 64, 256] + ) + parser.add_argument( + "--max-vllm-batch-size", + type=int, + help="Max vllm server batch size (max concurrency)", + ) + parser.add_argument("--backend", type=str, default="openai") + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--enforce-eager", action="store_true") + parser.add_argument("--enable-chunked-prefill", action="store_true") + parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) + parser.add_argument("--disable-padded-drafter-batch", action="store_true") + parser.add_argument("--result-dir", type=str, 
default="./log/dynamic_sd") + parser.add_argument("--prompt-lookup-max", type=int, default=5) + parser.add_argument("--prompt-lookup-min", type=int, default=2) + parser.add_argument("--max-model-len", type=int, default=16384) + parser.add_argument("--temp", type=float, default=0) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--top-k", type=int, default=-1) + parser.add_argument("--output-len", type=int, default=256) + parser.add_argument("--parallel-drafting", action="store_true") + parser.add_argument("--allowed-local-media-path", type=str, default="") + parser.add_argument( + "--num-batches", + type=int, + default=20, + help="Number of batches to run for each benchmark.", + ) + parser.add_argument("--custom-mm-prompts", action="store_true") + + args = parser.parse_args() + args.enable_chunked_prefill = True + args.enforce_eager = False + args.print_output = False + args.num_spec_tokens = max(args.num_speculative_tokens_list) + args.eagle_dir = args.draft_dir + args.result_dir = ( + f"{args.result_dir}/tp-{args.tp}_temp-{args.temp}" + f"_top_p-{args.top_p}_top_k-{args.top_k}" + f"/{args.dataset_path}/" + ) + + assert args.max_vllm_batch_size == max(args.batch_size_list), ( + "max_vllm_batch_size must be equal to max of batch_size_list" + ) + + pprint.pprint(vars(args)) + start = time.time() + + # Step 1: get acceptance_rate_per_pos + acceptance_length, acceptance_rate_per_pos = get_acceptance_rate_per_pos(args) + print(f"Acceptance length: {acceptance_length}") + print(f"Acceptance rate per position: {acceptance_rate_per_pos}") + print("✅ Step 1: obtained acceptance rate per position.") + + # Step 2: generate benchmark data using vllm bench sweep serve + # This runs the Cartesian product of: + # serve_params: [vanilla, {method}-k-1, {method}-k-3, {method}-k-5] + # bench_params: [bs-1, bs-4, bs-16, bs-64, bs-256] + # The sweep framework handles server start/stop for each serve config. 
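As an aside before the sweep is launched below: a toy, self-contained version of what Step 3's parser (`parse_itl_from_dataframe`, defined earlier in this file) does with the sweep output. The column names (`max_concurrency`, `speculative_config`, `median_itl_ms`) are the ones the parser reads; the DataFrame rows are invented:

```python
import json

import pandas as pd

rows = pd.DataFrame(
    [
        # Vanilla run: no speculative_config present -> K = 0.
        {"max_concurrency": 1, "speculative_config": None, "median_itl_ms": 6.87},
        # Spec decode run: K read from speculative_config (dict or JSON string).
        {
            "max_concurrency": 1,
            "speculative_config": json.dumps({"num_speculative_tokens": 3}),
            "median_itl_ms": 9.41,
        },
    ]
)

batch_stats: dict[int, dict[int, float]] = {}
for _, row in rows.iterrows():
    cfg = row["speculative_config"]
    if isinstance(cfg, dict):
        k = int(cfg["num_speculative_tokens"])
    elif isinstance(cfg, str):
        k = int(json.loads(cfg)["num_speculative_tokens"])
    else:
        k = 0  # vanilla row
    batch_stats.setdefault(int(row["max_concurrency"]), {})[k] = row["median_itl_ms"]

print(batch_stats)  # {1: {0: 6.87, 3: 9.41}}
```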
+ result_df = run_profiling_sweep(args) + print(f"✅ Step 2: benchmark data generated for vanilla and {args.method}.") + + # Step 3: parse batch_stats from benchmark data + batch_stats = parse_itl_from_dataframe(result_df) + print("✅ Step 3: parsed batch statistics from benchmark data.") + print(json.dumps(batch_stats, indent=4)) + + # Step 4: Save DynamicSpeculativeConfig to a json file + dynamic_config = DynamicSpeculativeConfig( + is_online=False, + max_num_speculative_tokens=len(acceptance_rate_per_pos), + acceptance_rate_per_pos=acceptance_rate_per_pos, + batch_stats=batch_stats, + ) + + config_path = f"{args.result_dir}/dynamic_speculative_config.json" + with open(config_path, "w") as f: + json.dump(dataclasses.asdict(dynamic_config), f, indent=4) + + print(f"✅ Step 4: config saved to {config_path}") + + end = time.time() + print(f"Total time taken: {end - start:.2f} seconds") + + +""" +time python3 vllm/v1/spec_decode/dynamic/generate_config.py \ + --method eagle \ + --model-dir 'meta-llama/Llama-3.1-8B-Instruct' \ + --draft-dir 'yuhuili/EAGLE-LLaMA3.1-Instruct-8B' \ + --tp 1 \ + --temp 0 \ + --top-p 1.0 \ + --top-k -1 \ + --max-vllm-batch-size 256 \ + --batch-size-list 1 4 16 64 256 \ + --num-speculative-tokens-list 1 3 5 \ + --num-batches 20 \ + --dataset-name hf \ + --dataset-path 'philschmid/mt-bench' \ + --no-oversample \ + --result-dir './log/dynamic_sd_test' + +# shorter version: +time python3 vllm/v1/spec_decode/dynamic/generate_config.py \ + --method eagle \ + --model-dir 'meta-llama/Llama-3.1-8B-Instruct' \ + --draft-dir 'yuhuili/EAGLE-LLaMA3.1-Instruct-8B' \ + --tp 1 \ + --temp 0 \ + --top-p 1.0 \ + --top-k -1 \ + --max-vllm-batch-size 256 \ + --batch-size-list 1 256 \ + --num-speculative-tokens-list 1 5 \ + --num-batches 5 \ + --dataset-name hf \ + --dataset-path 'philschmid/mt-bench' \ + --no-oversample \ + --result-dir './log/dynamic_sd_test_short' +""" +if __name__ == "__main__": + main() diff --git a/vllm/v1/spec_decode/dynamic/manager.py b/vllm/v1/spec_decode/dynamic/manager.py new file mode 100644 index 000000000000..6c49fea29cad --- /dev/null +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.config.speculative import DynamicSpeculativeConfig +from vllm.v1.spec_decode.metrics import SpecDecodingStats + + +class DynamicSpeculativeDecodingManager: + """A mapping from batch size to optimal number of drafts to use for that + batch size. 
This is used to dynamically adjust the number of drafts used + based on the current batch size.""" + + _optimal_num_speculative_tokens: dict[int, int] + + def __init__( + self, + dynamic_config: DynamicSpeculativeConfig, + vllm_max_batch_size: int, + vllm_num_speculative_tokens: int, + warmup_steps: int = 10, # TODO: make this configurable + ): + self.dynamic_config = dynamic_config + self.vllm_max_batch_size = vllm_max_batch_size + self.vllm_num_speculative_tokens = vllm_num_speculative_tokens + + assert dynamic_config.batch_stats is not None, ( + "batch_stats is required for dynamic speculative decoding" + ) + assert dynamic_config.max_num_speculative_tokens is not None, ( + "max_num_speculative_tokens is required for dynamic speculative decoding" + ) + assert dynamic_config.acceptance_rate_per_pos is not None, ( + "acceptance_rate_per_pos is required for dynamic speculative decoding" + ) + + self.batch_stats: dict[int, dict[int, float]] = dynamic_config.batch_stats + self.max_num_speculative_tokens: int = dynamic_config.max_num_speculative_tokens + self.acceptance_rate_per_pos: list[float] = ( + dynamic_config.acceptance_rate_per_pos + ) + self.available_batch_sizes = sorted(dynamic_config.batch_stats.keys()) + self.steps = 0 + self.warmup_steps = warmup_steps + + # Cumulative stats for online acceptance rate updates + self.stats = SpecDecodingStats.new(vllm_num_speculative_tokens) + + # Sanity check + assert ( + vllm_num_speculative_tokens <= dynamic_config.max_num_speculative_tokens + ), "vllm_num_speculative_tokens must be <= max_num_speculative_tokens" + + assert dynamic_config.max_num_speculative_tokens == len( + dynamic_config.acceptance_rate_per_pos + ), "max_num_speculative_tokens != len(acceptance_rate_per_pos)" + assert dynamic_config.max_num_speculative_tokens > 0, ( + "max_num_speculative_tokens must be > 0" + ) + assert all(0.0 <= a <= 1.0 for a in dynamic_config.acceptance_rate_per_pos), ( + "all acceptance_rate_per_pos values must be in [0.0, 1.0]" + ) + assert 1 in dynamic_config.batch_stats, ( + f"BS 1 not found in {dynamic_config.batch_stats.keys()}" + ) + assert vllm_max_batch_size in dynamic_config.batch_stats, ( + f"max BS not found in {dynamic_config.batch_stats.keys()}" + ) + + for bs in self.available_batch_sizes: + assert bs > 0 + assert 0 in dynamic_config.batch_stats[bs], ( + f"batch size {bs} must have draft 0 stats" + ) + assert 1 in dynamic_config.batch_stats[bs], ( + f"batch size {bs} must have draft 1 stats" + ) + assert sorted(dynamic_config.batch_stats[bs].keys()) == list( + dynamic_config.batch_stats[bs].keys() + ), f"batch size {bs} draft keys must be sorted" + + self.update_optimal_num_speculative_tokens() + + def step(self, batch_size: int) -> int: + self.steps += 1 + if self.should_update(): + acceptance_rate_per_pos = self.compute_acceptance_rate_per_pos() + self.update_acceptance_rate_per_pos(acceptance_rate_per_pos) + + optimal_num_speculative_tokens = self.get_optimal_num_speculative_tokens( + batch_size + ) + + return optimal_num_speculative_tokens + + def compute_acceptance_rate_per_pos(self) -> list[float]: + acceptance_rate_per_pos = [] + for i in range(self.vllm_num_speculative_tokens): + if self.stats.num_draft_tokens_per_pos[i] == 0: + acceptance_rate = 0.0 + else: + acceptance_rate = ( + self.stats.num_accepted_tokens_per_pos[i] + / self.stats.num_draft_tokens_per_pos[i] + ) + + acceptance_rate_per_pos.append(acceptance_rate) + + return acceptance_rate_per_pos + + def should_update(self) -> bool: + # making this a separate function for 
easier overriding or extension + return self.steps > self.warmup_steps + + def get_optimal_num_speculative_tokens(self, batch_size: int) -> int: + assert batch_size > 0, "batch_size must be > 0" + assert batch_size <= self.vllm_max_batch_size, ( + "batch_size must be <= vllm_max_batch_size" + ) + return self._optimal_num_speculative_tokens[batch_size] + + def update_optimal_num_speculative_tokens(self): + self._optimal_num_speculative_tokens = { + bs: self._compute_optimal_num_speculative_tokens(bs) + for bs in range(1, self.vllm_max_batch_size + 1) + } + + def update_acceptance_rate_per_pos(self, acceptance_rate_per_pos: list[float]): + self.acceptance_rate_per_pos = acceptance_rate_per_pos + self.dynamic_config.acceptance_rate_per_pos = acceptance_rate_per_pos + self.update_optimal_num_speculative_tokens() + + def _get_batch_stats(self, batch_size: int) -> dict: + if batch_size not in self.batch_stats: + # find the nearest batch size smaller and bigger than the given batch size + # and return the weighted avg of their stats + + smaller_bs_list = [ + bs for bs in self.available_batch_sizes if bs < batch_size + ] + smaller_bs = ( + max(smaller_bs_list) + if len(smaller_bs_list) + else self.available_batch_sizes[0] + ) + larger_bs_list = [ + bs for bs in self.available_batch_sizes if bs > batch_size + ] + larger_bs = ( + min(larger_bs_list) + if len(larger_bs_list) + else self.available_batch_sizes[-1] + ) + + smaller_bs_stat = self.batch_stats[smaller_bs] + larger_bs_stat = self.batch_stats[larger_bs] + + if larger_bs == smaller_bs: + return self.batch_stats[smaller_bs] + + ratio = (batch_size - smaller_bs) / (larger_bs - smaller_bs) + + avg_stat: dict[int, float] = {} + for k in smaller_bs_stat: + avg_stat[k] = smaller_bs_stat[k] + ratio * ( + larger_bs_stat[k] - smaller_bs_stat[k] + ) + + return avg_stat + else: + return self.batch_stats[batch_size] + + def _get_itl(self, batch_stats, num_drafts: int) -> float: + if num_drafts in batch_stats: + return batch_stats[num_drafts] + else: + lower_num_draft = max(k for k in batch_stats if k < num_drafts) + upper_num_draft = min(k for k in batch_stats if k > num_drafts) + + ratio = (num_drafts - lower_num_draft) / (upper_num_draft - lower_num_draft) + lower_itl = batch_stats[lower_num_draft] + upper_itl = batch_stats[upper_num_draft] + return lower_itl + ratio * (upper_itl - lower_itl) + + def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: + batch_stats = self._get_batch_stats(batch_size) + + max_goodput = -1.0 + for num_drafts in range(self.max_num_speculative_tokens + 1): + curr_al = 1 + sum(self.acceptance_rate_per_pos[:num_drafts]) + curr_itl = self._get_itl(batch_stats, num_drafts) + curr_goodput = curr_al / curr_itl + if curr_goodput > max_goodput: + max_goodput = curr_goodput + chosen_num_drafts = num_drafts + + return chosen_num_drafts + + def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int) -> None: + """Record draft/accept counts for online acceptance rate updates.""" + self.stats.observe_draft(num_draft_tokens, num_accepted_tokens) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a03b707dd347..d9f8298d2212 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -399,6 +399,7 @@ def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor: def propose( self, + num_speculative_tokens: int, # [num_tokens] target_token_ids: torch.Tensor, # [num_tokens] or [3, num_tokens] when M-RoPE is enabled @@ -416,6 +417,8 @@ def propose( | 
list[dict[str, torch.Tensor]] | None = None, ) -> torch.Tensor: + self.num_speculative_tokens = num_speculative_tokens + batch_size = common_attn_metadata.batch_size() if self.method in ("eagle3", "dflash"): @@ -475,6 +478,17 @@ def propose( sample_hidden_states = last_hidden_states[token_indices_to_sample] + # No draft tokens requested (e.g. Dynamic SD decided K=0). + # The prefill forward pass above already ran to keep the drafter + # KV cache in sync, so just return an empty tensor. + if self.num_speculative_tokens == 0: + return torch.empty( + batch_size, + 0, + device=sample_hidden_states.device, + dtype=torch.int64, + ) + # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1 or self.parallel_drafting: draft_token_ids = self._greedy_sample(sample_hidden_states) diff --git a/vllm/v1/spec_decode/extract_hidden_states.py b/vllm/v1/spec_decode/extract_hidden_states.py index e26fa768a324..7c3b929e8729 100644 --- a/vllm/v1/spec_decode/extract_hidden_states.py +++ b/vllm/v1/spec_decode/extract_hidden_states.py @@ -27,7 +27,10 @@ class ExtractHiddenStatesProposer: def __init__(self, vllm_config: VllmConfig, device): assert vllm_config.speculative_config is not None - assert vllm_config.speculative_config.num_speculative_tokens == 1 + self.num_speculative_tokens = ( + vllm_config.speculative_config.num_speculative_tokens + ) + assert self.num_speculative_tokens == 1 if vllm_config.speculative_config.disable_padded_drafter_batch: raise ValueError( "disable_padded_drafter_batch is not supported with " @@ -71,6 +74,7 @@ def __init__(self, vllm_config: VllmConfig, device): def propose( self, + num_speculative_tokens: int, sampled_token_ids: torch.Tensor, target_hidden_states: list[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, @@ -101,6 +105,7 @@ def propose( - Draft tokens matching sampled tokens, shape [batch_size, 1] - KV connector output (if KV transfer is active), else None """ + assert num_speculative_tokens == self.num_speculative_tokens assert self.model is not None and isinstance(target_hidden_states, list) # target_hidden_states is a list of tensors (one per layer) diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 80b0f0a9870a..7adf7cff5f77 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -35,15 +35,18 @@ def __init__( self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.hidden_size = self.spec_config.draft_model_config.get_hidden_size() self.dtype = vllm_config.model_config.dtype + self.num_speculative_tokens = self.spec_config.num_speculative_tokens def propose( self, + num_speculative_tokens: int, target_hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, # unused ) -> torch.Tensor: + assert num_speculative_tokens == self.num_speculative_tokens # Generate blocks and compute logits blocks = self.model(target_hidden_states) logits = self.model.compute_logits(blocks) diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 9a41ff5c818c..b5c8fcc05969 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -28,12 +28,14 @@ class SpecDecodingStats: num_draft_tokens: int = 0 num_accepted_tokens: int = 0 num_accepted_tokens_per_pos: list[int] = field(default_factory=list) + num_draft_tokens_per_pos: list[int] = field(default_factory=list) @classmethod def new(cls, num_spec_tokens: int) -> 
"SpecDecodingStats": return cls( num_spec_tokens=num_spec_tokens, num_accepted_tokens_per_pos=[0] * num_spec_tokens, + num_draft_tokens_per_pos=[0] * num_spec_tokens, ) def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int): @@ -43,6 +45,8 @@ def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int): assert num_accepted_tokens <= self.num_spec_tokens for i in range(num_accepted_tokens): self.num_accepted_tokens_per_pos[i] += 1 + for i in range(num_draft_tokens): + self.num_draft_tokens_per_pos[i] += 1 class SpecDecodingLogging: diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 53199d0ce217..e0240d0e66b8 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -55,6 +55,7 @@ def __init__(self, vllm_config: VllmConfig): # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose( + self.k, [[]] * 1024, np.zeros(1024, dtype=np.int32), np.zeros((1024, self.max_model_len), dtype=np.int32), @@ -66,6 +67,7 @@ def batch_propose( valid_ngram_requests: list, num_tokens_no_spec: np.ndarray, token_ids_cpu: np.ndarray, + k: int, ) -> list[list[int]]: """Batch version of ngram proposer using numba for acceleration. @@ -78,6 +80,8 @@ def batch_propose( token_ids_cpu: Numpy array of shape (batch_size, max_model_len) representing the token IDs for each request. + k: + Number of speculative tokens to propose. Returns: list[list[int]]: @@ -110,7 +114,7 @@ def batch_propose( self.min_n, self.max_n, self.max_model_len, - self.k, + k, self.valid_ngram_draft, self.valid_ngram_num_drafts, ) @@ -130,6 +134,7 @@ def batch_propose( def propose( self, + num_speculative_tokens: int, sampled_token_ids: list[list[int]], num_tokens_no_spec: np.ndarray, token_ids_cpu: np.ndarray, @@ -137,6 +142,8 @@ def propose( | list[dict[str, torch.Tensor]] | None = None, # unused ) -> list[list[int]]: + assert num_speculative_tokens <= self.k + # find which requests need ngram proposals valid_ngram_requests = [] for i, sampled_ids in enumerate(sampled_token_ids): @@ -157,6 +164,7 @@ def propose( valid_ngram_requests, num_tokens_no_spec, token_ids_cpu, + num_speculative_tokens, ) return draft_token_ids diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py index eb24a9c933e2..1f1566e4bfc4 100644 --- a/vllm/v1/spec_decode/ngram_proposer_gpu.py +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -313,6 +313,7 @@ def _generate_dummy_data( def propose( self, + num_speculative_tokens: int, num_tokens_no_spec: torch.Tensor, # [batch_size] token_ids_gpu: torch.Tensor, # [batch_size, max_len] valid_sampled_token_ids_gpu: torch.Tensor, # [batch_size, num_spec_tokens + 1] @@ -325,6 +326,7 @@ def propose( updated lengths, then run the kernel. Args: + num_speculative_tokens: Number of speculative tokens to propose. 
num_tokens_no_spec: Number of tokens per sequence (read-only) token_ids_gpu: Token IDs tensor (modified in-place with new tokens) valid_sampled_token_ids_gpu: Newly sampled tokens to scatter @@ -335,6 +337,7 @@ def propose( num_valid_draft_tokens: Count of leading valid draft tokens per request [batch_size] """ + assert num_speculative_tokens == self.k assert token_ids_gpu.device == self.device assert num_tokens_no_spec.device == self.device diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index fee5d97468f3..66137a006316 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -34,12 +34,14 @@ def __init__(self, vllm_config: VllmConfig): def propose( self, + num_speculative_tokens: int, input_batch: InputBatch, sampled_token_ids: list[list[int]], slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, # unused ) -> list[list[int]]: + assert num_speculative_tokens == self.num_speculative_tokens """ Propose speculative tokens for each request in the input batch. Suffix Decoding will speculate a dynamic number of tokens for each request every decoding step, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a950827879ad..34c5fb8731fe 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -577,9 +577,11 @@ def __init__( self.rejection_sampler = RejectionSampler(self.sampler) self.num_spec_tokens = 0 + self.prev_num_spec_tokens = 0 self.valid_sampled_token_count_gpu: torch.Tensor | None = None if self.speculative_config: self.num_spec_tokens = self.speculative_config.num_speculative_tokens + self.prev_num_spec_tokens = self.num_spec_tokens draft_config = self.speculative_config.draft_model_config if draft_config is not None and draft_config.max_model_len is not None: self.effective_drafter_max_model_len = draft_config.max_model_len @@ -1669,7 +1671,7 @@ def _prepare_input_ids( spec_flattened_indices.extend( range(flattened_index - draft_len + 1, flattened_index + 1) ) - start = prev_index * self.num_spec_tokens + start = prev_index * self.prev_num_spec_tokens # prev_draft_token_indices is used to find which draft_tokens_id # should be copied to input_ids # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] @@ -4256,8 +4258,6 @@ def propose_draft_token_ids(sampled_token_ids): self._copy_valid_sampled_token_count( next_token_ids, valid_sampled_tokens_count ) - # Since we couldn't run the drafter, - # just use zeros for the draft tokens. self._draft_token_ids = torch.zeros( 1, device=self.device, dtype=torch.int32 ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) @@ -4397,6 +4397,9 @@ def take_draft_token_ids(self) -> DraftTokenIds | None: def _copy_draft_token_ids_to_cpu( self, scheduler_output: "SchedulerOutput", zeros_only: bool = False ) -> None: + if torch.is_tensor(self._draft_token_ids): + assert isinstance(self._draft_token_ids, torch.Tensor) + self.prev_num_spec_tokens = self._draft_token_ids.shape[1] # Check if we need to copy draft tokens to CPU. In async scheduling, # we only copy when needed for structured output, penalties or bad_words. 
if self.use_async_scheduling and not ( @@ -4415,16 +4418,17 @@ def _copy_draft_token_ids_to_cpu( assert self.draft_token_ids_cpu is not None default_stream = torch.cuda.current_stream() num_reqs = draft_token_ids.shape[0] + num_spec_tokens = draft_token_ids.shape[1] with torch.cuda.stream(self.draft_token_ids_copy_stream): if not zeros_only: # Trigger async copy of draft token ids to cpu. self.draft_token_ids_copy_stream.wait_stream(default_stream) - self.draft_token_ids_cpu[:num_reqs].copy_( + self.draft_token_ids_cpu[:num_reqs, :num_spec_tokens].copy_( draft_token_ids, non_blocking=True ) else: # No copy needed, just zero-out cpu tensor. - self.draft_token_ids_cpu[:num_reqs] = 0 + self.draft_token_ids_cpu[:num_reqs, :num_spec_tokens] = 0 self.draft_token_ids_event.record() def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]: @@ -4436,7 +4440,11 @@ def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]: assert self.draft_token_ids_event is not None assert self.draft_token_ids_cpu is not None self.draft_token_ids_event.synchronize() - return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids + assert isinstance(self._draft_token_ids, torch.Tensor) + num_spec_tokens = self._draft_token_ids.shape[1] + return self.draft_token_ids_cpu[ + : len(req_ids), :num_spec_tokens + ].tolist(), req_ids def _copy_valid_sampled_token_count( self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor @@ -4487,12 +4495,14 @@ def propose_draft_token_ids( num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None + num_spec_tokens_to_schedule = scheduler_output.num_spec_tokens_to_schedule if spec_config.method == "ngram": from vllm.v1.spec_decode.ngram_proposer import NgramProposer assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( + num_spec_tokens_to_schedule, sampled_token_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, @@ -4518,6 +4528,7 @@ def propose_draft_token_ids( batch_size = next_token_ids.shape[0] draft_token_ids, num_valid_draft_tokens = self.drafter.propose( + num_spec_tokens_to_schedule, self.num_tokens_no_spec_gpu[:batch_size], self.token_ids_gpu_tensor[:batch_size], valid_sampled_token_ids_gpu, @@ -4539,7 +4550,10 @@ def propose_draft_token_ids( assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer) draft_token_ids = self.drafter.propose( - self.input_batch, sampled_token_ids, slot_mappings=slot_mappings + num_spec_tokens_to_schedule, + self.input_batch, + sampled_token_ids, + slot_mappings=slot_mappings, ) elif spec_config.method == "medusa": assert isinstance(sampled_token_ids, list) @@ -4563,6 +4577,7 @@ def propose_draft_token_ids( hidden_states = sample_hidden_states[indices] draft_token_ids = self.drafter.propose( + num_speculative_tokens=num_spec_tokens_to_schedule, target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, slot_mappings=slot_mappings, @@ -4580,6 +4595,7 @@ def propose_draft_token_ids( target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states] draft_token_ids = self.drafter.propose( + num_speculative_tokens=num_spec_tokens_to_schedule, sampled_token_ids=sampled_token_ids, target_hidden_states=target_hidden_states, common_attn_metadata=common_attn_metadata, @@ -4704,6 +4720,7 @@ def propose_draft_token_ids( mm_embed_inputs = None draft_token_ids = 
self.drafter.propose( + num_speculative_tokens=num_spec_tokens_to_schedule, target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states,
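A standalone sketch of the K-selection math in `DynamicSpeculativeDecodingManager._compute_optimal_num_speculative_tokens` above, with invented profiling numbers for a single batch size. Expected accepted length is 1 (the bonus token) plus the sum of per-position acceptance rates, goodput is that length divided by the ITL, and missing draft counts are linearly interpolated as in `_get_itl`; the same interpolation is applied across profiled batch sizes by `_get_batch_stats`, which this sketch omits:

```python
# Made-up offline statistics for one batch size.
acceptance_rate_per_pos = [0.8, 0.6, 0.4, 0.3, 0.2]
batch_stats = {0: 6.87, 1: 8.0, 3: 9.41, 5: 10.8}  # ITL in ms per draft count K


def itl_at(k: int) -> float:
    # Linear interpolation between profiled draft counts (mirrors _get_itl).
    if k in batch_stats:
        return batch_stats[k]
    lo = max(x for x in batch_stats if x < k)
    hi = min(x for x in batch_stats if x > k)
    r = (k - lo) / (hi - lo)
    return batch_stats[lo] + r * (batch_stats[hi] - batch_stats[lo])


best_k, best_goodput = 0, -1.0
for k in range(len(acceptance_rate_per_pos) + 1):
    # Expected tokens per target step: bonus token + expected accepted drafts.
    al = 1 + sum(acceptance_rate_per_pos[:k])
    goodput = al / itl_at(k)  # tokens per millisecond
    print(f"K={k}: AL={al:.2f}, ITL={itl_at(k):.2f} ms, goodput={goodput:.3f}")
    if goodput > best_goodput:
        best_k, best_goodput = k, goodput

print(f"chosen K = {best_k}")  # K=4 wins here: more drafts stop paying off at K=5
```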
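Finally, a schematic illustration (invented shapes and values, mirroring but not copying the runner code above) of why `gpu_model_runner.py` now tracks `prev_num_spec_tokens` and slices its persistent buffers: when the scheduler lowers K for a step, the draft tensor is narrower than the preallocated buffer, and the next step's flattened indexing into the previous drafts must use the K that was actually drafted, not the configured maximum:

```python
import torch

# Persistent CPU buffer sized for the configured maximum K...
max_k, num_reqs = 5, 3
draft_token_ids_cpu = torch.zeros(num_reqs, max_k, dtype=torch.int32)

# ...but this step the scheduler chose K=3, so the drafter returned a
# [num_reqs, 3] tensor and only that slice is copied (or zeroed).
draft_token_ids = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.int32)
k = draft_token_ids.shape[1]
draft_token_ids_cpu[:num_reqs, :k] = draft_token_ids

# Next step, flattened indexing into last step's drafts must use the stride
# actually used then (prev_num_spec_tokens == 3), not max_k == 5:
prev_index = 1
start = prev_index * k  # request 1's drafts begin at flat offset 3, not 5
print(draft_token_ids_cpu[:num_reqs, :k].flatten()[start : start + k])
# tensor([4, 5, 6], dtype=torch.int32)
```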