From 8e7bcc97a135721adc32fe55f81e2f1c14c52e70 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 10 Oct 2025 14:43:22 +0000 Subject: [PATCH 01/39] save work Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/config/speculative.py | 39 ++++ vllm/model_executor/models/interfaces.py | 29 +++ vllm/v1/spec_decode/dynamic.py | 138 +++++++++++++ vllm/v1/spec_decode/dynamic_profiling.py | 63 ++++++ vllm/v1/spec_decode/eagle.py | 8 +- vllm/v1/spec_decode/medusa.py | 1 + vllm/v1/spec_decode/ngram_proposer.py | 14 +- .../v1/spec_decode/online_profiling_client.py | 185 ++++++++++++++++++ .../v1/spec_decode/online_profiling_server.py | 122 ++++++++++++ vllm/v1/worker/gpu_model_runner.py | 22 +++ 10 files changed, 616 insertions(+), 5 deletions(-) create mode 100644 vllm/v1/spec_decode/dynamic.py create mode 100644 vllm/v1/spec_decode/dynamic_profiling.py create mode 100644 vllm/v1/spec_decode/online_profiling_client.py create mode 100644 vllm/v1/spec_decode/online_profiling_server.py diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index d5c6d1d4d866..f3f9d9deeb07 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -36,6 +36,41 @@ MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp", "qwen3_next_mtp", "longcat_flash_mtp") +@dataclass +class DynamicSpeculativeConfig: + """A mapping from batch size to optimal number of drafts to use for that + batch size. This is used to dynamically adjust the number of drafts used + based on the current batch size.""" + optimal_num_speculative_tokens: dict[int, int] = None + + """Whether the statistics are updated online or not during inference.""" + is_online: bool = False + + """ + Batch statistics for different batch sizes and number of drafts. + The structure is as follows: + { + batch_size: { + num_drafts: itl (i.e., inter token latency in ms) + } + } + + e.g., + { + 1: { 0: 6.87, 3: 9.41, 5: 10.8}, + 4: { 0: 7.3, 3: 9.95, 5: 11.59}, + } + + where bs 1 at K=3 has itl 9.41ms. K=0 means no speculative decoding. 
+ """ + batch_stats: dict[int, dict[int, float]] = None + + """Maximum number of speculative tokens supported in the statistics.""" + max_num_speculative_tokens: int = None + + """Acceptance rate per position on an offline dataset.""" + acceptance_rate_per_pos: list[float] = None + @config @dataclass @@ -117,6 +152,10 @@ class SpeculativeConfig: """Whether to disable the periodic printing of stage times in speculative decoding.""" + # dynamic speculative decoding control + """Configuration for dynamic speculative decoding, if provided.""" + dynamic_config: Optional[DynamicSpeculativeConfig] = None + # params generated in the post-init stage draft_model_config: SkipValidation[ModelConfig] = None # type: ignore """The configuration of the draft model initialized internal.""" diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index c95c63cd8534..415249561b3c 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -905,6 +905,35 @@ def supports_v0_only( return getattr(model, "supports_v0_only", False) +@runtime_checkable +class SpeculativeDecodingProposer(Protocol): + """The interface required for all models that support speculative decoding + proposer.""" + + """Number of speculative tokens to propose.""" + num_speculative_tokens: int + + def propose( + self, + *args, + **kwargs: object, + ) -> Union[torch.Tensor, list[list[int]]]: + """ + Propose multiple tokens for speculative decoding. + + Args: + input_ids: Input token IDs of shape (batch_size, seq_len). + attention_mask: Optional attention mask of shape + (batch_size, seq_len). + **kwargs: Additional model-specific arguments. + + Returns: + A tensor of shape (batch_size, num_proposed_tokens) or a list of + lists of token IDs, where each inner list contains the proposed + token IDs for the corresponding batch item. + """ + ... 
+ @runtime_checkable class SupportsEagle3(Protocol): """The interface required for models that support diff --git a/vllm/v1/spec_decode/dynamic.py b/vllm/v1/spec_decode/dynamic.py new file mode 100644 index 000000000000..d0676a000093 --- /dev/null +++ b/vllm/v1/spec_decode/dynamic.py @@ -0,0 +1,138 @@ +from typing import Any, Tuple, Optional, List +from vllm.vllm.config.speculative import DynamicSpeculativeConfig + + +_DYNAMIC_STATS = { + "max_num_speculative_tokens": 7, + "acceptance_rate_per_pos": [0.68, 0.39, 0.20, 0.10, 0.06, 0.03, 0.02], # E1 + # "acceptance_rate_per_pos": [0.76, 0.54, 0.39, 0.28, 0.21, 0.15, 0.12], # E3 + # "acceptance_rate_per_pos": [0.17, 0.15, 0.12, 0.01, 0.01, 0.01, 0.01], # E1 - low + "batch_stats": { + 1: { 0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29, }, + 4: { 0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29, }, + 16: { 0: 7.3, 1: 8.39, 3: 9.95, 4: 10.8, 5: 11.59, 7: 13.11, }, + 32: { 0: 7.64, 1: 8.97, 3: 10.78, 4: 11.79, 5: 12.81, 7: 14.86, }, + 64: { 0: 8.53, 1: 10.44, 3: 13.16, 4: 15.7, 5: 17.54, 7: 120.57, }, + } +} + +class DynamicSpeculativeDecodingManager: + def __init__(self, + dynamic_config: Optional[DynamicSpeculativeConfig], + vllm_max_batch_size: int, + vllm_num_speculative_tokens: int): + self.dynamic_config = dynamic_config + self.vllm_max_batch_size = vllm_max_batch_size + self.optimal_num_speculative_tokens = self.dynamic_config.optimal_num_speculative_tokens + self.batch_stats = self.dynamic_config.batch_stats + self.available_batch_sizes = sorted(self.dynamic_config.batch_stats.keys()) + + # Sanity check + assert vllm_num_speculative_tokens <= self.dynamic_config.max_num_speculative_tokens, \ + "vllm_num_speculative_tokens must be <= max_num_speculative_tokens" + + if self.dynamic_config.is_online: + assert self.dynamic_config.max_num_speculative_tokens == len(self.dynamic_config.acceptance_rate_per_pos), \ + "max_num_speculative_tokens must be equal to the length of acceptance_rate_per_pos" + assert self.dynamic_config.max_num_speculative_tokens > 0, "max_num_speculative_tokens must be > 0" + assert all(0 < a < 1 for a in self.dynamic_config.acceptance_rate_per_pos), "all acceptance_rate_per_pos values must be in (0, 1)" + assert 1 in self.dynamic_config.batch_stats, "batch size 1 must be available" + assert vllm_max_batch_size in self.dynamic_config.batch_stats, \ + f"vllm max_num_seqs {vllm_max_batch_size} must be available" + + for bs in self.available_batch_sizes: + assert bs > 0 + assert 0 in self.dynamic_config.batch_stats[bs], \ + f"batch size {bs} must have draft 0 stats" + assert 1 in self.dynamic_config.batch_stats[bs], \ + f"batch size {bs} must have draft 1 stats" + assert sorted(self.dynamic_config.batch_stats[bs].keys()) == \ + list(self.dynamic_config.batch_stats[bs].keys()), \ + f"batch size {bs} draft keys must be sorted" + + + + def get_optimal_num_speculative_tokens(self, batch_size: int) -> int: + return self.optimal_num_speculative_tokens[batch_size] + + + def update_optimal_num_speculative_tokens(self): + self.optimal_num_speculative_tokens = { + bs: self._compute_optimal_num_speculative_tokens(bs) \ + for bs in range(1, self.vllm_max_batch_size) + } + + + def _get_batch_stats(self, batch_size: int) -> dict: + # import pdb; pdb.set_trace() + if batch_size not in self.batch_stats: + # find the nearest batch size smaller and bigger than the given batch size + # and return the weighted avg of their stats + print(f"Finding batch stats for batch_size: {batch_size} in self.available_batch_sizes: 
{self.available_batch_sizes}") + + smaller_bs = [bs for bs in self.available_batch_sizes if bs < batch_size] + smaller_bs = max(smaller_bs) if len(smaller_bs) else self.available_batch_sizes[0] + larger_bs = [bs for bs in self.available_batch_sizes if bs > batch_size] + larger_bs = min(larger_bs) if len(larger_bs) else self.available_batch_sizes[-1] + + + # REMOVE + print(f"smaller_bs: {smaller_bs}, larger_bs: {larger_bs}, batch_size: {batch_size}") + + smaller_bs_stat = self.batch_stats[smaller_bs] + larger_bs_stat = self.batch_stats[larger_bs] + + ratio = (batch_size - smaller_bs) / (larger_bs - smaller_bs) + + # REMOVE + print(f"ratio: {ratio}") + + avg_stat: dict[int, float] = {} + for k in smaller_bs_stat.keys(): + avg_stat[k] = smaller_bs_stat[k] + ratio * (larger_bs_stat[k] - smaller_bs_stat[k]) + + return avg_stat + else: + return self.batch_stats[batch_size] + + + def _get_itl(self, batch_stats, num_drafts: int) -> float: + if num_drafts in batch_stats: + return batch_stats[num_drafts] + else: + lower_num_draft = max(k for k in batch_stats.keys() if k < num_drafts) + upper_num_draft = min(k for k in batch_stats.keys() if k > num_drafts) + + # REMOVE + # print(f"lower_num_draft: {lower_num_draft}, upper_num_draft: {upper_num_draft}, num_drafts: {num_drafts}") + + ratio = (num_drafts - lower_num_draft) / (upper_num_draft - lower_num_draft) + lower_itl = batch_stats[lower_num_draft] + upper_itl = batch_stats[upper_num_draft] + return lower_itl + ratio * (upper_itl - lower_itl) + + + def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: + batch_stats = self._get_batch_stats(batch_size) + + max_goodput = -1 + for num_drafts in range(self.dynamic_sd_stats["max_num_speculative_tokens"] + 1): + curr_al = 1 + sum(self.dynamic_sd_stats["acceptance_rate_per_pos"][:num_drafts]) + curr_itl = self._get_itl(batch_stats, num_drafts) + curr_goodput = curr_al / curr_itl + if curr_goodput > max_goodput: + max_goodput = curr_goodput + chosen_num_drafts = num_drafts + + # REMOVE + print(f"num_drafts: {num_drafts}, al: {curr_al}, itl: {curr_itl}, goodput: {curr_goodput}") + + return chosen_num_drafts + + +if __name__ == "__main__": + # print(_get_batch_stats(21)) + dynamic_sd = DynamicSpeculativeDecodingManager(_DYNAMIC_STATS) + for i in range(4, 64, 4): + print("\n====================================") + print(f"bs: {i}, optimal num drafts: {dynamic_sd.get_optimal_num_speculative_tokens(i)}") \ No newline at end of file diff --git a/vllm/v1/spec_decode/dynamic_profiling.py b/vllm/v1/spec_decode/dynamic_profiling.py new file mode 100644 index 000000000000..abd57915628d --- /dev/null +++ b/vllm/v1/spec_decode/dynamic_profiling.py @@ -0,0 +1,63 @@ +import re +import os +import json +import argparse +from vllm.v1.spec_decode.online_profiling_client import (NGRAM_FMT, EAGLE_FMT) + + +def parse_itl(args): + """ + DynamicSpeculativeConfig.batch_stats: dict + The structure is as follows: + { + batch_size: { + num_drafts: itl (i.e., inter token latency in ms) + } + } + """ + batch_stats = {} + + for method in ["vanilla", args.sd_method]: + # find the names of all log files in this folder + args.benchmark_path = os.path.join(args.benchmark_path_parent, method) + all_log_files = [f for f in os.listdir(args.benchmark_path) \ + if os.path.isfile(os.path.join(args.benchmark_path, f)) \ + and f.endswith(".txt")] + + # parse the log files to get the config params + for log_file in all_log_files: + # find bs + bs = re.search(r'_bs-(\d+)', log_file).group(1) + + # find sd params + spec_config_str = 
log_file.split("_")[0] + if method == "ngram": + FMT = NGRAM_FMT.replace("{}", "(.+)") + min, max, k = re.match(FMT, spec_config_str).groups() + elif method == "eagle": + FMT = EAGLE_FMT.replace("{}", "(.+)") + k = re.match(FMT, spec_config_str).groups()[0] + + # read the log file to get the itl + with open(os.path.join(args.benchmark_path, log_file), "r") as f: + data = json.load(f) + itl = data["median_itl_ms"] + + # add to batch_stats + if int(bs) not in batch_stats: + batch_stats[int(bs)] = {} + + batch_stats[int(bs)][int(k)] = itl + + return batch_stats + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--sd-method", type=str, default=None) + parser.add_argument("--benchmark-path-parent", type=str, default=None, help="Root folder which has the log files") + + args = parser.parse_args() + assert args.sd_method in ["ngram", "eagle"], "Invalid method specified." + + batch_stats = parse_itl(args) \ No newline at end of file diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index dc6db0138806..e4270bd2fcd5 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -34,13 +34,14 @@ from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.ubatching import dbo_current_ubatch_id +from vllm.vllm.model_executor.models.interfaces import SpeculativeDecodingProposer logger = init_logger(__name__) PADDING_SLOT_ID = -1 -class EagleProposer: +class EagleProposer(SpeculativeDecodingProposer): def __init__( self, @@ -177,6 +178,7 @@ def _set_positions(self, num_tokens: int, positions: torch.Tensor): def propose( self, + optimal_num_speculative_tokens: Optional[int], # [num_tokens] target_token_ids: torch.Tensor, # [num_tokens] or [3, num_tokens] when M-RoPE is enabled @@ -191,6 +193,10 @@ def propose( mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, ) -> torch.Tensor: + # Use optimal num speculative tokens if provided + if optimal_num_speculative_tokens is not None: + self.num_speculative_tokens = optimal_num_speculative_tokens + num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 70b29c05c2a5..9a152ca5951a 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -33,6 +33,7 @@ def __init__( draft_model_config.get_hidden_size( ) self.dtype = vllm_config.model_config.dtype + self.num_speculative_tokens = None def propose( self, diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index aed050a3540c..cfe9070b90cc 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -5,10 +5,12 @@ import numpy as np from numba import get_num_threads, jit, njit, prange, set_num_threads +from typing import Optional from vllm.config import VllmConfig +from vllm.vllm.model_executor.models.interfaces import SpeculativeDecodingProposer -class NgramProposer: +class NgramProposer(SpeculativeDecodingProposer): def __init__(self, vllm_config: VllmConfig): assert vllm_config.speculative_config is not None @@ -22,13 +24,13 @@ def __init__(self, vllm_config: VllmConfig): # Number of tokens follow the match. If there are less than k # tokens follow the match, we will return the maximum amount of # tokens until the end. 
- self.k = vllm_config.speculative_config.num_speculative_tokens + self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens # Maximum length of the model. self.max_model_len = vllm_config.model_config.max_model_len # Pre-allocate buffers for numba batch propose. max_num_seqs = vllm_config.scheduler_config.max_num_seqs - self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), + self.valid_ngram_draft = np.zeros((max_num_seqs, self.num_speculative_tokens), dtype=np.int32) self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32) @@ -104,7 +106,7 @@ def batch_propose( batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, token_ids_cpu, self.min_n, self.max_n, - self.max_model_len, self.k, + self.max_model_len, self.num_speculative_tokens, self.valid_ngram_draft, self.valid_ngram_num_drafts) @@ -123,12 +125,16 @@ def batch_propose( def propose( self, + optimal_num_speculative_tokens: Optional[int], sampled_token_ids: list[list[int]], req_ids: list[str], num_tokens_no_spec: np.ndarray, token_ids_cpu: np.ndarray, spec_decode_unsupported_reqs: set, ) -> list[list[int]]: + # Use optimal num speculative tokens if provided + if optimal_num_speculative_tokens is not None: + self.num_speculative_tokens = optimal_num_speculative_tokens # find which requests need ngram proposals valid_ngram_requests = [] diff --git a/vllm/v1/spec_decode/online_profiling_client.py b/vllm/v1/spec_decode/online_profiling_client.py new file mode 100644 index 000000000000..be4e93e39f85 --- /dev/null +++ b/vllm/v1/spec_decode/online_profiling_client.py @@ -0,0 +1,185 @@ +import os +import subprocess +import pandas as pd +from dataclasses import dataclass +import argparse +from vllm.vllm.v1.spec_decode.online_profiling_server import ( + start_server, + kill_server, + setup_server) + + +@dataclass +class Dataset: + name: str + config: list + +NGRAM_FMT = "min-{min}-max-{max}-k-{k}" +EAGLE_FMT = "k-{k}" + +def run_command(command): + try: + result = subprocess.run(f"bash -c '{command}'", + shell=True, + check=True, + capture_output=True, + text=True) + print("Output:") + print(result.stdout) + except subprocess.CalledProcessError as e: + print("Error:") + print(e.stderr) + + +def run_benchmarks(args): + + # setup server + setup_server() + + port=9001 + all_sampling_profile=[ + {'temperature': 0, 'topp': 1}, # greedy + ] + + MTBENCH_CONFIG = [{"num_samples_per_seq": 20}] + + all_bench_dataset = [ + Dataset(name = "philschmid/mt-bench", config = MTBENCH_CONFIG), + ] + + assert (all(len(ds.config) > 0 for ds in all_bench_dataset)), "Each dataset must have at least one config" + + all_ngram_params = [{"min": 2, "max": 5, "k": k} for k in args.num_speculative_tokens] + all_eagle_params = args.num_speculative_tokens + + # ablation + num_exp_run = 0 + for tp in args.tp: + for spec_method in args.method_list: + # collate all spec configs to run for a given method + all_spec_config = [] + if spec_method == "ngram": + for ngram_params in all_ngram_params: + all_spec_config.append({ + "method": "ngram", + "num_speculative_tokens": ngram_params['k'], + "prompt_lookup_max": ngram_params['max'], + "prompt_lookup_min": ngram_params['min'], + }) + elif spec_method == "eagle": + for eagle_k in all_eagle_params: + all_spec_config.append({ + "method": "eagle", + "model": args.draft_dir, + "num_speculative_tokens": eagle_k, + "draft_tensor_parallel_size": tp, + }) + else: + # vanilla case + all_spec_config.append(None) + + for spec_config in all_spec_config: + # start server + server_process = 
start_server(port=port, + target_model_dir=args.model_dir, + spec_config=spec_config, + tp=tp, + max_vllm_bs=args.max_vllm_batch_size, + dry_run=args.dry_run) + + # start client + for bench_concurrency in args.batch_size_list: + for bench_dataset_object in all_bench_dataset: + bench_dataset = bench_dataset_object.name + for bench_config in bench_dataset_object.config: + for sampling_profile in all_sampling_profile: + bench_temperature = sampling_profile['temperature'] + bench_topp = sampling_profile['topp'] + + spec_config_str = "vanilla" + if spec_method == "ngram": + spec_config_str = NGRAM_FMT.format( + min=spec_config['prompt_lookup_min'], + max=spec_config['prompt_lookup_max'], + k=spec_config['num_speculative_tokens'] + ) + elif spec_method == "eagle": + spec_config_str = EAGLE_FMT.format( + k=spec_config['num_speculative_tokens'] + ) + + # dataset specific config + if "philschmid/mt-bench" in bench_dataset: + bench_config_str = f"mt_bench" + num_prompts = bench_config["num_prompts"] * bench_concurrency + bench_vllm_serve_config = f'--dataset-name hf --dataset-path {bench_dataset} --num-prompts {num_prompts}' + + print(f"Number of prompts in {bench_dataset}: {num_prompts}") + + # create dir if not exists + result_dir = f"{args.result_dir}/tp-{tp}_temp-{bench_temperature}_top_p-{bench_topp}/{bench_dataset}/{spec_method}/online/" + if not os.path.exists(result_dir): + os.makedirs(result_dir) + + cmd = f'''time vllm bench serve --port {port} --save-result --save-detailed \ + --model {args.model_dir} \ + --backend openai-chat \ + --endpoint /v1/chat/completions \ + {bench_vllm_serve_config} \ + --max-concurrency {bench_concurrency} \ + --temperature={bench_temperature} \ + --top-p={bench_topp} \ + --result-dir "{result_dir}" \ + --result-filename "{spec_config_str}_{bench_config_str}_bs-{bench_concurrency}_{args.extra_log_arg}.txt"''' + + print(cmd) + num_exp_run += 1 + + if not args.dry_run: + run_command(cmd) + + # server teardown: kill server and any gpu processes + kill_server(port, server_process) + + print(f"Total number of experiments run: {num_exp_run}") + + +# time python3 vllm/cohere/utils/eagle/sweep_ngram_eagle_online_benchmark.py +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dry-run", action="store_true", help="Run in dry run mode. 
If set, commands will be printed but not executed.") + parser.add_argument("--model-dir", type=str, default=None) + parser.add_argument("--draft-dir", type=str, default=None) + parser.add_argument("--method-list", type=list[str], default=["vanilla", "eagle"]) + parser.add_argument("--num-speculative-tokens-list", type=list[int], default=[1, 3, 5]) + parser.add_argument("--batch-size-list", type=list[int], default=[1, 4, 8, 16, 32, 64, 128]) + parser.add_argument("--max-vllm-batch-size", type=int, help="Max vllm server batch size (max concurrency)") + parser.add_argument("--tp-list", type=list[int], default=[1]) + parser.add_argument("--result-dir", type=str, default="./log") + parser.add_argument("--extra-log-arg", type=str, default="") + args = parser.parse_args() + + assert all([method in ["vanilla", "ngram", "eagle", "eagle3"] for method in args.method_list]), \ + "invalid method in method_list" + + assert 1 in args.batch_size_list, "batch_size must contain 1" + assert 1 in args.num_speculative_tokens_list, "num_speculative_tokens must contain 1" + assert args.max_vllm_batch_size == max(args.batch_size_list), \ + "max_vllm_batch_size must be equal to max of batch_size" + + model_dir = args.model_dir + args.model_dir = "meta-llama/Llama-3.1-8B-Instruct" if args.model_dir is None else args.model_dir + + if args.method == "eagle" or args.method == "eagle3": + if args.method == "eagle" and args.eagle_dir is None: + args.eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + + elif args.method == "eagle3" and args.eagle_dir is None: + args.eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + speculative_config = { + "method": args.method, + "model": args.eagle_dir, + "num_speculative_tokens": args.num_spec_tokens, + } + + run_benchmarks(args) \ No newline at end of file diff --git a/vllm/v1/spec_decode/online_profiling_server.py b/vllm/v1/spec_decode/online_profiling_server.py new file mode 100644 index 000000000000..7d512079eb7b --- /dev/null +++ b/vllm/v1/spec_decode/online_profiling_server.py @@ -0,0 +1,122 @@ +import os +import subprocess +import json +import time +import shutil + +""" +Utility functions to manage the vLLM server for online profiling. +Main functions are setup_server(), start_server(), and kill_server(). 
+""" + +def wait_for_server(port: int) -> bool: + timeout = 1200 # 20 mins + start_time = time.time() + while time.time() - start_time < timeout: + try: + subprocess.run(["curl", "-X", "POST", f"localhost:{port}/v1/completions"], check=True) + return True + except subprocess.CalledProcessError: + time.sleep(10) # wait for 10 seconds before retrying + return False + + +def kill_gpu_processes(port: int): + subprocess.run(["ps", "-aux"]) + subprocess.run([f"lsof -t -i:{port} | xargs -r kill -9"], shell=True) + + # Use ps to list all Python processes and grep to exclude the specific one + command = ["ps", "aux"] + ps_output = subprocess.check_output(command, text=True) + + # Do not kill this process + filename = os.path.basename(__file__) + pids_to_kill = [] + for line in ps_output.split("\n"): + if "python3" in line and filename not in line: + pid = line.split()[1] + pids_to_kill.append(pid) + + # Kill other processes + for pid in pids_to_kill: + subprocess.run(["kill", "-9", pid]) + + # Wait until all GPUs have memory usage < 1000 MB + if shutil.which("nvidia-smi"): + while True: + # Get GPU memory usage for all GPUs + memory_usage = subprocess.check_output(["nvidia-smi", + "--query-gpu=memory.used", + "--format=csv,noheader,nounits"], + text=True) + # Split the output into individual GPU memory usage values + gpu_memory_usage = [int(x) for x in memory_usage.strip().split("\n")] + # Check if any GPU has memory usage >= 1000 MB + if all(usage < 1000 for usage in gpu_memory_usage): + break + time.sleep(1) + elif shutil.which("amd-smi"): + while True: + memory_usage = subprocess.check_output(["amd-smi", + "metric", + "-g", + "0"], + text=True) + used_vram = int(memory_usage.split("USED_VRAM")[1].split()[0]) + if used_vram < 1000: + break + time.sleep(1) + + subprocess.run(["rm", "-rf", "~/.config/vllm"]) + + +def setup_server(): + # install dependencies + dependencies = ["lsof", "curl", "pgrep"] + for dep in dependencies: + if not shutil.which(dep): + subprocess.run(["apt-get", "update"]) + subprocess.run(["apt-get", "install", "-y", dep]) + + +def start_server(port: int, + target_model_dir: str, + spec_config: dict | None, + tp: int, + max_vllm_bs: int, + dry_run: bool = False) -> subprocess.Popen | None: + + # NOTE: no Prompt Caching, but enabled chunked prefill + server_command = f"""VLLM_USE_V1=1 vllm serve {target_model_dir} \ + --disable-log-requests --port {port} \ + --gpu_memory_utilization 0.95 \ + --max_num_seqs {max_vllm_bs} \ + --tensor_parallel_size {tp} \ + --enable-chunked-prefill \ + --no-enable-prefix-caching """ + + if spec_config: + speculative_config_json_serialized = json.dumps(spec_config).replace('"', '\\"') + server_command += f'--speculative_config "{speculative_config_json_serialized}" ' + + print(f"Server command: {server_command}") + + # start vllm server + if not dry_run: + server_process = subprocess.Popen(server_command, shell=True) + + if wait_for_server(port): + print("vllm server is up and running.") + else: + print("vllm failed to start within the timeout period.") + server_process.kill() + + return server_process + else: + return None + + +def kill_server(port: int, server_process: subprocess.Popen | None): + if server_process: + server_process.kill() + kill_gpu_processes(port) \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index efb4a8c0054f..bac1d5b7177c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -97,6 +97,7 @@ from vllm.v1.spec_decode.medusa 
import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.dynamic import DynamicSpeculativeDecodingManager from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -297,6 +298,12 @@ def __init__( f"{self.speculative_config.method}") self.rejection_sampler = RejectionSampler() + # setup Dynamic Speculative Decoding + self.dynamic_sd_manager = DynamicSpeculativeDecodingManager( + self.speculative_config.dynamic_config, + self.vllm_config.scheduler_config.max_num_seqs, + self.vllm_config.speculative_config.num_speculative_tokens) + # Request states. self.requests: dict[str, CachedRequestState] = {} self.comm_stream = torch.cuda.Stream() @@ -2568,11 +2575,25 @@ def propose_draft_token_ids( spec_decode_metadata: Optional[SpecDecodeMetadata], common_attn_metadata: CommonAttentionMetadata, ) -> Union[list[list[int]], torch.Tensor]: + + optimal_num_speculative_tokens = None + if self.dynamic_sd_manager: + batch_size = self.input_batch.num_reqs + optimal_num_speculative_tokens = self.dynamic_sd_manager.\ + get_optimal_num_speculative_tokens( + self.input_batch.num_reqs + ) + + # REMOVE + print(f"Batch size: {batch_size}, " + f"Optimal num speculative tokens: {optimal_num_speculative_tokens}") + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( + optimal_num_speculative_tokens, sampled_token_ids, self.input_batch.req_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, @@ -2677,6 +2698,7 @@ def propose_draft_token_ids( mm_embed_inputs = None draft_token_ids = self.drafter.propose( + optimal_num_speculative_tokens=optimal_num_speculative_tokens, target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, From bee764bc39e21a838ebe1eb83182bfe166c16225 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Sun, 4 Jan 2026 04:16:53 +0000 Subject: [PATCH 02/39] move to dynamic folder and scripts are in working condition Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/config/speculative.py | 8 +- vllm/v1/spec_decode/dynamic.py | 138 ------------ vllm/v1/spec_decode/dynamic/dynamic.py | 169 +++++++++++++++ .../dynamic/online_profiling_client.py | 197 ++++++++++++++++++ .../{ => dynamic}/online_profiling_server.py | 42 +++- .../process_benchmark_results.py} | 32 ++- vllm/v1/spec_decode/eagle.py | 2 +- vllm/v1/spec_decode/ngram_proposer.py | 2 +- .../v1/spec_decode/online_profiling_client.py | 185 ---------------- vllm/v1/worker/gpu_model_runner.py | 11 +- 10 files changed, 438 insertions(+), 348 deletions(-) delete mode 100644 vllm/v1/spec_decode/dynamic.py create mode 100644 vllm/v1/spec_decode/dynamic/dynamic.py create mode 100644 vllm/v1/spec_decode/dynamic/online_profiling_client.py rename vllm/v1/spec_decode/{ => dynamic}/online_profiling_server.py (81%) rename vllm/v1/spec_decode/{dynamic_profiling.py => dynamic/process_benchmark_results.py} (65%) delete mode 100644 vllm/v1/spec_decode/online_profiling_client.py diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 
f3f9d9deeb07..707cdbc207f3 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -38,10 +38,10 @@ @dataclass class DynamicSpeculativeConfig: - """A mapping from batch size to optimal number of drafts to use for that - batch size. This is used to dynamically adjust the number of drafts used - based on the current batch size.""" - optimal_num_speculative_tokens: dict[int, int] = None + # """A mapping from batch size to optimal number of drafts to use for that + # batch size. This is used to dynamically adjust the number of drafts used + # based on the current batch size.""" + # optimal_num_speculative_tokens: dict[int, int] = None """Whether the statistics are updated online or not during inference.""" is_online: bool = False diff --git a/vllm/v1/spec_decode/dynamic.py b/vllm/v1/spec_decode/dynamic.py deleted file mode 100644 index d0676a000093..000000000000 --- a/vllm/v1/spec_decode/dynamic.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import Any, Tuple, Optional, List -from vllm.vllm.config.speculative import DynamicSpeculativeConfig - - -_DYNAMIC_STATS = { - "max_num_speculative_tokens": 7, - "acceptance_rate_per_pos": [0.68, 0.39, 0.20, 0.10, 0.06, 0.03, 0.02], # E1 - # "acceptance_rate_per_pos": [0.76, 0.54, 0.39, 0.28, 0.21, 0.15, 0.12], # E3 - # "acceptance_rate_per_pos": [0.17, 0.15, 0.12, 0.01, 0.01, 0.01, 0.01], # E1 - low - "batch_stats": { - 1: { 0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29, }, - 4: { 0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29, }, - 16: { 0: 7.3, 1: 8.39, 3: 9.95, 4: 10.8, 5: 11.59, 7: 13.11, }, - 32: { 0: 7.64, 1: 8.97, 3: 10.78, 4: 11.79, 5: 12.81, 7: 14.86, }, - 64: { 0: 8.53, 1: 10.44, 3: 13.16, 4: 15.7, 5: 17.54, 7: 120.57, }, - } -} - -class DynamicSpeculativeDecodingManager: - def __init__(self, - dynamic_config: Optional[DynamicSpeculativeConfig], - vllm_max_batch_size: int, - vllm_num_speculative_tokens: int): - self.dynamic_config = dynamic_config - self.vllm_max_batch_size = vllm_max_batch_size - self.optimal_num_speculative_tokens = self.dynamic_config.optimal_num_speculative_tokens - self.batch_stats = self.dynamic_config.batch_stats - self.available_batch_sizes = sorted(self.dynamic_config.batch_stats.keys()) - - # Sanity check - assert vllm_num_speculative_tokens <= self.dynamic_config.max_num_speculative_tokens, \ - "vllm_num_speculative_tokens must be <= max_num_speculative_tokens" - - if self.dynamic_config.is_online: - assert self.dynamic_config.max_num_speculative_tokens == len(self.dynamic_config.acceptance_rate_per_pos), \ - "max_num_speculative_tokens must be equal to the length of acceptance_rate_per_pos" - assert self.dynamic_config.max_num_speculative_tokens > 0, "max_num_speculative_tokens must be > 0" - assert all(0 < a < 1 for a in self.dynamic_config.acceptance_rate_per_pos), "all acceptance_rate_per_pos values must be in (0, 1)" - assert 1 in self.dynamic_config.batch_stats, "batch size 1 must be available" - assert vllm_max_batch_size in self.dynamic_config.batch_stats, \ - f"vllm max_num_seqs {vllm_max_batch_size} must be available" - - for bs in self.available_batch_sizes: - assert bs > 0 - assert 0 in self.dynamic_config.batch_stats[bs], \ - f"batch size {bs} must have draft 0 stats" - assert 1 in self.dynamic_config.batch_stats[bs], \ - f"batch size {bs} must have draft 1 stats" - assert sorted(self.dynamic_config.batch_stats[bs].keys()) == \ - list(self.dynamic_config.batch_stats[bs].keys()), \ - f"batch size {bs} draft keys must be sorted" - - - - def 
get_optimal_num_speculative_tokens(self, batch_size: int) -> int: - return self.optimal_num_speculative_tokens[batch_size] - - - def update_optimal_num_speculative_tokens(self): - self.optimal_num_speculative_tokens = { - bs: self._compute_optimal_num_speculative_tokens(bs) \ - for bs in range(1, self.vllm_max_batch_size) - } - - - def _get_batch_stats(self, batch_size: int) -> dict: - # import pdb; pdb.set_trace() - if batch_size not in self.batch_stats: - # find the nearest batch size smaller and bigger than the given batch size - # and return the weighted avg of their stats - print(f"Finding batch stats for batch_size: {batch_size} in self.available_batch_sizes: {self.available_batch_sizes}") - - smaller_bs = [bs for bs in self.available_batch_sizes if bs < batch_size] - smaller_bs = max(smaller_bs) if len(smaller_bs) else self.available_batch_sizes[0] - larger_bs = [bs for bs in self.available_batch_sizes if bs > batch_size] - larger_bs = min(larger_bs) if len(larger_bs) else self.available_batch_sizes[-1] - - - # REMOVE - print(f"smaller_bs: {smaller_bs}, larger_bs: {larger_bs}, batch_size: {batch_size}") - - smaller_bs_stat = self.batch_stats[smaller_bs] - larger_bs_stat = self.batch_stats[larger_bs] - - ratio = (batch_size - smaller_bs) / (larger_bs - smaller_bs) - - # REMOVE - print(f"ratio: {ratio}") - - avg_stat: dict[int, float] = {} - for k in smaller_bs_stat.keys(): - avg_stat[k] = smaller_bs_stat[k] + ratio * (larger_bs_stat[k] - smaller_bs_stat[k]) - - return avg_stat - else: - return self.batch_stats[batch_size] - - - def _get_itl(self, batch_stats, num_drafts: int) -> float: - if num_drafts in batch_stats: - return batch_stats[num_drafts] - else: - lower_num_draft = max(k for k in batch_stats.keys() if k < num_drafts) - upper_num_draft = min(k for k in batch_stats.keys() if k > num_drafts) - - # REMOVE - # print(f"lower_num_draft: {lower_num_draft}, upper_num_draft: {upper_num_draft}, num_drafts: {num_drafts}") - - ratio = (num_drafts - lower_num_draft) / (upper_num_draft - lower_num_draft) - lower_itl = batch_stats[lower_num_draft] - upper_itl = batch_stats[upper_num_draft] - return lower_itl + ratio * (upper_itl - lower_itl) - - - def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: - batch_stats = self._get_batch_stats(batch_size) - - max_goodput = -1 - for num_drafts in range(self.dynamic_sd_stats["max_num_speculative_tokens"] + 1): - curr_al = 1 + sum(self.dynamic_sd_stats["acceptance_rate_per_pos"][:num_drafts]) - curr_itl = self._get_itl(batch_stats, num_drafts) - curr_goodput = curr_al / curr_itl - if curr_goodput > max_goodput: - max_goodput = curr_goodput - chosen_num_drafts = num_drafts - - # REMOVE - print(f"num_drafts: {num_drafts}, al: {curr_al}, itl: {curr_itl}, goodput: {curr_goodput}") - - return chosen_num_drafts - - -if __name__ == "__main__": - # print(_get_batch_stats(21)) - dynamic_sd = DynamicSpeculativeDecodingManager(_DYNAMIC_STATS) - for i in range(4, 64, 4): - print("\n====================================") - print(f"bs: {i}, optimal num drafts: {dynamic_sd.get_optimal_num_speculative_tokens(i)}") \ No newline at end of file diff --git a/vllm/v1/spec_decode/dynamic/dynamic.py b/vllm/v1/spec_decode/dynamic/dynamic.py new file mode 100644 index 000000000000..910609808cf9 --- /dev/null +++ b/vllm/v1/spec_decode/dynamic/dynamic.py @@ -0,0 +1,169 @@ +from typing import Any, Tuple, Optional, List +from vllm.config.speculative import DynamicSpeculativeConfig + + +# _DYNAMIC_STATS = { +# "max_num_speculative_tokens": 7, +# 
"acceptance_rate_per_pos": [0.68, 0.39, 0.20, 0.10, 0.06, 0.03, 0.02], # E1 +# # "acceptance_rate_per_pos": [0.76, 0.54, 0.39, 0.28, 0.21, 0.15, 0.12], # E3 +# # "acceptance_rate_per_pos": [0.17, 0.15, 0.12, 0.01, 0.01, 0.01, 0.01], # E1 - low +# "batch_stats": { +# 1: { 0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29, }, +# 4: { 0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29, }, +# 16: { 0: 7.3, 1: 8.39, 3: 9.95, 4: 10.8, 5: 11.59, 7: 13.11, }, +# 32: { 0: 7.64, 1: 8.97, 3: 10.78, 4: 11.79, 5: 12.81, 7: 14.86, }, +# 64: { 0: 8.53, 1: 10.44, 3: 13.16, 4: 15.7, 5: 17.54, 7: 120.57, }, +# } +# } + +_DYNAMIC_STATS = DynamicSpeculativeConfig( + max_num_speculative_tokens=7, + acceptance_rate_per_pos=[0.68, 0.39, 0.20, 0.10, 0.06, 0.03, 0.02], # E1 + # acceptance_rate_per_pos=[0.76, 0.54, 0.39, 0.28, 0.21, 0.15, 0.12], # E3 + # acceptance_rate_per_pos=[0.17, 0.15, 0.12, 0.01, 0.01, 0.01, 0.01], # E1 - low + batch_stats={ + 1: {0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29,}, + 4: {0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29,}, + 16: {0: 7.3, 1: 8.39, 3: 9.95, 4: 10.8, 5: 11.59, 7: 13.11,}, + 32: {0: 7.64, 1: 8.97, 3: 10.78, 4: 11.79, 5: 12.81, 7: 14.86,}, + 64: {0: 8.53, 1: 10.44, 3: 13.16, 4: 15.7, 5: 17.54, 7: 120.57,}, + 128: {0: 8.53, 1: 15.44, 3: 25.16, 4: 30.7, 5: 37.54, 7: 220.57,}, # fake + } +) + +class DynamicSpeculativeDecodingManager: + + """A mapping from batch size to optimal number of drafts to use for that + batch size. This is used to dynamically adjust the number of drafts used + based on the current batch size.""" + _optimal_num_speculative_tokens: dict[int, int] + + def __init__(self, + dynamic_config: Optional[DynamicSpeculativeConfig], + vllm_max_batch_size: int, + vllm_num_speculative_tokens: int): + self.dynamic_config = dynamic_config + self.vllm_max_batch_size = vllm_max_batch_size + self.batch_stats = self.dynamic_config.batch_stats + self.available_batch_sizes = sorted(self.dynamic_config.batch_stats.keys()) + + # Sanity check + assert vllm_num_speculative_tokens <= self.dynamic_config.max_num_speculative_tokens, \ + "vllm_num_speculative_tokens must be <= max_num_speculative_tokens" + + # if self.dynamic_config.is_online: + assert self.dynamic_config.max_num_speculative_tokens == len(self.dynamic_config.acceptance_rate_per_pos), \ + "max_num_speculative_tokens must be equal to the length of acceptance_rate_per_pos" + assert self.dynamic_config.max_num_speculative_tokens > 0, "max_num_speculative_tokens must be > 0" + assert all(0 < a < 1 for a in self.dynamic_config.acceptance_rate_per_pos), "all acceptance_rate_per_pos values must be in (0, 1)" + assert 1 in self.dynamic_config.batch_stats, f"batch size 1 must be available, found: {self.dynamic_config.batch_stats.keys()}" + assert vllm_max_batch_size in self.dynamic_config.batch_stats, \ + f"vllm max_num_seqs {vllm_max_batch_size} must be available, found: {self.dynamic_config.batch_stats.keys()}" + + for bs in self.available_batch_sizes: + assert bs > 0 + assert 0 in self.dynamic_config.batch_stats[bs], \ + f"batch size {bs} must have draft 0 stats" + assert 1 in self.dynamic_config.batch_stats[bs], \ + f"batch size {bs} must have draft 1 stats" + assert sorted(self.dynamic_config.batch_stats[bs].keys()) == \ + list(self.dynamic_config.batch_stats[bs].keys()), \ + f"batch size {bs} draft keys must be sorted" + + self.update_optimal_num_speculative_tokens() + + + + def get_optimal_num_speculative_tokens(self, batch_size: int) -> int: + assert batch_size > 0, "batch_size must be > 0" + 
assert batch_size <= self.vllm_max_batch_size, \ + "batch_size must be <= vllm_max_batch_size" + return self._optimal_num_speculative_tokens[batch_size] + + + def update_optimal_num_speculative_tokens(self): + self._optimal_num_speculative_tokens = { + bs: self._compute_optimal_num_speculative_tokens(bs) \ + for bs in range(1, self.vllm_max_batch_size + 1) + } + + + def _get_batch_stats(self, batch_size: int) -> dict: + # import pdb; pdb.set_trace() + if batch_size not in self.batch_stats: + # find the nearest batch size smaller and bigger than the given batch size + # and return the weighted avg of their stats + print(f"Finding batch stats for batch_size: {batch_size} in self.available_batch_sizes: {self.available_batch_sizes}") + + smaller_bs = [bs for bs in self.available_batch_sizes if bs < batch_size] + smaller_bs = max(smaller_bs) if len(smaller_bs) else self.available_batch_sizes[0] + larger_bs = [bs for bs in self.available_batch_sizes if bs > batch_size] + larger_bs = min(larger_bs) if len(larger_bs) else self.available_batch_sizes[-1] + + + # REMOVE + print(f"smaller_bs: {smaller_bs}, larger_bs: {larger_bs}, batch_size: {batch_size}") + + smaller_bs_stat = self.batch_stats[smaller_bs] + larger_bs_stat = self.batch_stats[larger_bs] + + ratio = (batch_size - smaller_bs) / (larger_bs - smaller_bs) + + # REMOVE + print(f"ratio: {ratio}") + + avg_stat: dict[int, float] = {} + for k in smaller_bs_stat.keys(): + avg_stat[k] = smaller_bs_stat[k] + ratio * (larger_bs_stat[k] - smaller_bs_stat[k]) + + return avg_stat + else: + return self.batch_stats[batch_size] + + + def _get_itl(self, batch_stats, num_drafts: int) -> float: + if num_drafts in batch_stats: + return batch_stats[num_drafts] + else: + lower_num_draft = max(k for k in batch_stats.keys() if k < num_drafts) + upper_num_draft = min(k for k in batch_stats.keys() if k > num_drafts) + + # REMOVE + # print(f"lower_num_draft: {lower_num_draft}, upper_num_draft: {upper_num_draft}, num_drafts: {num_drafts}") + + ratio = (num_drafts - lower_num_draft) / (upper_num_draft - lower_num_draft) + lower_itl = batch_stats[lower_num_draft] + upper_itl = batch_stats[upper_num_draft] + return lower_itl + ratio * (upper_itl - lower_itl) + + + def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: + batch_stats = self._get_batch_stats(batch_size) + + max_goodput = -1 + for num_drafts in range(self.dynamic_config.max_num_speculative_tokens + 1): + curr_al = 1 + sum(self.dynamic_config.acceptance_rate_per_pos[:num_drafts]) + curr_itl = self._get_itl(batch_stats, num_drafts) + curr_goodput = curr_al / curr_itl + if curr_goodput > max_goodput: + max_goodput = curr_goodput + chosen_num_drafts = num_drafts + + # REMOVE + print(f"num_drafts: {num_drafts}, al: {curr_al}, itl: {curr_itl}, goodput: {curr_goodput}") + + return chosen_num_drafts + + +# python3 vllm/v1/spec_decode/dynamic.py +if __name__ == "__main__": + # print(_get_batch_stats(21)) + MAX_TEST_BS = 128 + dynamic_sd = DynamicSpeculativeDecodingManager( + dynamic_config=_DYNAMIC_STATS, + vllm_max_batch_size=MAX_TEST_BS, + vllm_num_speculative_tokens=7, + ) + for i in range(4, MAX_TEST_BS+1, 4): + print("\n====================================") + print(f"bs: {i}, optimal num drafts: {dynamic_sd.get_optimal_num_speculative_tokens(i)}") \ No newline at end of file diff --git a/vllm/v1/spec_decode/dynamic/online_profiling_client.py b/vllm/v1/spec_decode/dynamic/online_profiling_client.py new file mode 100644 index 000000000000..3e9baae37bdc --- /dev/null +++ 
b/vllm/v1/spec_decode/dynamic/online_profiling_client.py @@ -0,0 +1,197 @@ +import os +import subprocess +import pandas as pd +from dataclasses import dataclass +import argparse +from vllm.v1.spec_decode.dynamic.online_profiling_server import ( + start_server, + kill_server, + setup_server) + + +@dataclass +class Dataset: + name: str + config: list + +NGRAM_FMT = "min-{min}-max-{max}-k-{k}" +EAGLE_FMT = "k-{k}" + +def run_command(command): + try: + result = subprocess.run(f"bash -c '{command}'", + shell=True, + check=True, + capture_output=True, + text=True) + print("Output:") + print(result.stdout) + except subprocess.CalledProcessError as e: + print("Error:") + print(e.stderr) + + +def run_benchmarks(args): + + # setup server + setup_server() + + port=9001 + all_sampling_profile=[ + {'temperature': 0, 'topp': 1}, # greedy + ] + + # `num_batches` decides how many batches are sent for each concurrency. + # E.g., num_batches=20 and concurrency=4 means total 80 prompts are sent + # such that we send 20 batches of 4 prompts each. This ensures a consistent + # number of batches across different concurrencies. For e.g., if total + # samples is 80 then concurrency 1 will send 80 batches while concurrency 64 + # will send 2 batches only. + MTBENCH_CONFIG = [{"num_batches": 20}] + + all_bench_dataset = [ + Dataset(name = "philschmid/mt-bench", config = MTBENCH_CONFIG), + ] + + assert (all(len(ds.config) > 0 for ds in all_bench_dataset)), "Each dataset must have at least one config" + + all_ngram_params = [{"min": 2, "max": 5, "k": k} for k in args.num_speculative_tokens_list] + all_eagle_params = args.num_speculative_tokens_list + + # ablation + num_exp_run = 0 + for tp in args.tp_list: + spec_method = args.method + # collate all spec configs to run for a given method + all_spec_config = [] + if spec_method == "ngram": + for ngram_params in all_ngram_params: + all_spec_config.append({ + "method": "ngram", + "num_speculative_tokens": ngram_params['k'], + "prompt_lookup_max": ngram_params['max'], + "prompt_lookup_min": ngram_params['min'], + }) + elif spec_method == "eagle": + for eagle_k in all_eagle_params: + all_spec_config.append({ + "method": "eagle", + "model": args.draft_dir, + "num_speculative_tokens": eagle_k, + "draft_tensor_parallel_size": tp, + }) + else: + # vanilla case + all_spec_config.append(None) + + for spec_config in all_spec_config: + # start server + server_process = start_server(port=port, + target_model_dir=args.model_dir, + spec_config=spec_config, + tp=tp, + max_vllm_bs=args.max_vllm_batch_size, + dry_run=args.dry_run) + + # start client + for bench_concurrency in args.batch_size_list: + for bench_dataset_object in all_bench_dataset: + bench_dataset = bench_dataset_object.name + for bench_config in bench_dataset_object.config: + for sampling_profile in all_sampling_profile: + bench_temperature = sampling_profile['temperature'] + bench_topp = sampling_profile['topp'] + + spec_config_str = "vanilla" + if spec_method == "ngram": + spec_config_str = NGRAM_FMT.format( + min=spec_config['prompt_lookup_min'], + max=spec_config['prompt_lookup_max'], + k=spec_config['num_speculative_tokens'] + ) + elif spec_method == "eagle": + spec_config_str = EAGLE_FMT.format( + k=spec_config['num_speculative_tokens'] + ) + + # dataset specific config + if "philschmid/mt-bench" in bench_dataset: + bench_config_str = f"mt_bench" + num_prompts = bench_config["num_batches"] * bench_concurrency + bench_vllm_serve_config = f'--dataset-name hf --dataset-path {bench_dataset} --num-prompts {num_prompts}' 
+ + print(f"Number of prompts in {bench_dataset}: {num_prompts}") + + # create dir if not exists + result_dir = f"{args.result_dir}/tp-{tp}_temp-{bench_temperature}_top_p-{bench_topp}/{bench_dataset}/{spec_method}/" + if not os.path.exists(result_dir): + os.makedirs(result_dir) + + cmd = f'''time vllm bench serve --port {port} --save-result --save-detailed \ + --model {args.model_dir} \ + --backend openai-chat \ + --endpoint /v1/chat/completions \ + {bench_vllm_serve_config} \ + --max-concurrency {bench_concurrency} \ + --temperature={bench_temperature} \ + --top-p={bench_topp} \ + --result-dir "{result_dir}" \ + --result-filename "{spec_config_str}_{bench_config_str}_bs-{bench_concurrency}_{args.extra_log_arg}.txt"''' + + print(cmd) + num_exp_run += 1 + + if not args.dry_run: + run_command(cmd) + + # server teardown: kill server and any gpu processes + kill_server(port, server_process) + + print(f"Total number of experiments run: {num_exp_run}") + + +""" +# eagle +time python3 vllm/v1/spec_decode/online_profiling_client.py \ + --batch-size-list 1 4 16 64 256 \ + --num-speculative-tokens-list 1 3 5 \ + --max-vllm-batch-size 256 \ + --method eagle \ + --model-dir meta-llama/Llama-3.1-8B-Instruct \ + --draft-dir yuhuili/EAGLE-LLaMA3.1-Instruct-8B + +# vanilla +time python3 vllm/v1/spec_decode/online_profiling_client.py \ + --batch-size-list 1 4 16 64 256 \ + --max-vllm-batch-size 256 \ + --method vanilla +""" +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dry-run", action="store_true", help="Run in dry run mode. If set, commands will be printed but not executed.") + parser.add_argument("--model-dir", type=str, default=None) + parser.add_argument("--draft-dir", type=str, default=None) + # parser.add_argument("--method-list", nargs='*', type=str, default=["vanilla", "eagle"]) + parser.add_argument("--method", type=str, default="vanilla") + parser.add_argument("--num-speculative-tokens-list", nargs='*', type=int, default=[1, 3, 5]) + parser.add_argument("--batch-size-list", nargs='*', type=int, default=[1, 4, 16, 64, 256]) + parser.add_argument("--max-vllm-batch-size", type=int, help="Max vllm server batch size (max concurrency)") + parser.add_argument("--tp-list", nargs='*', type=int, default=[1]) + parser.add_argument("--result-dir", type=str, default="./log/dynamic_sd") + parser.add_argument("--extra-log-arg", type=str, default="") + args = parser.parse_args() + + # assert all([method in ["vanilla", "ngram", "eagle", "eagle3"] for method in args.method_list]), \ + # "invalid method in method_list" + # assert 0 < len(args.method_list) <= 2 + # if len(args.method_list) == 2: + # assert "vanilla" in args.method_list, "If two methods are specified, one must be vanilla" + assert args.method in ["vanilla", "ngram", "eagle", "eagle3"], \ + "invalid method specified" + + # assert 1 in args.batch_size_list, "batch_size must contain 1" + # assert 1 in args.num_speculative_tokens_list, "num_speculative_tokens must contain 1" + assert args.max_vllm_batch_size == max(args.batch_size_list), \ + "max_vllm_batch_size must be equal to max of batch_size" + + run_benchmarks(args) \ No newline at end of file diff --git a/vllm/v1/spec_decode/online_profiling_server.py b/vllm/v1/spec_decode/dynamic/online_profiling_server.py similarity index 81% rename from vllm/v1/spec_decode/online_profiling_server.py rename to vllm/v1/spec_decode/dynamic/online_profiling_server.py index 7d512079eb7b..b1a1fff6a473 100644 --- a/vllm/v1/spec_decode/online_profiling_server.py +++ 
b/vllm/v1/spec_decode/dynamic/online_profiling_server.py @@ -3,6 +3,7 @@ import json import time import shutil +import signal """ Utility functions to manage the vLLM server for online profiling. @@ -31,6 +32,10 @@ def kill_gpu_processes(port: int): # Do not kill this process filename = os.path.basename(__file__) + + # REMOVE + print(f"filename to exclude: {filename}") + pids_to_kill = [] for line in ps_output.split("\n"): if "python3" in line and filename not in line: @@ -41,6 +46,11 @@ def kill_gpu_processes(port: int): for pid in pids_to_kill: subprocess.run(["kill", "-9", pid]) + wait_for_gpu_memory_to_clear() + + # subprocess.run(["rm", "-rf", "~/.config/vllm"]) + +def wait_for_gpu_memory_to_clear(): # Wait until all GPUs have memory usage < 1000 MB if shutil.which("nvidia-smi"): while True: @@ -67,9 +77,6 @@ def kill_gpu_processes(port: int): break time.sleep(1) - subprocess.run(["rm", "-rf", "~/.config/vllm"]) - - def setup_server(): # install dependencies dependencies = ["lsof", "curl", "pgrep"] @@ -103,7 +110,7 @@ def start_server(port: int, # start vllm server if not dry_run: - server_process = subprocess.Popen(server_command, shell=True) + server_process = subprocess.Popen(server_command, shell=True, preexec_fn=os.setsid) if wait_for_server(port): print("vllm server is up and running.") @@ -116,7 +123,28 @@ def start_server(port: int, return None -def kill_server(port: int, server_process: subprocess.Popen | None): +# def kill_server(port: int, server_process: subprocess.Popen | None): +# if server_process: +# server_process.kill() +# kill_gpu_processes(port) + +def kill_server(port, server_process): + + # REMOVE + # print(f"Killing server on port {port}...") + if server_process: - server_process.kill() - kill_gpu_processes(port) \ No newline at end of file + os.killpg(os.getpgid(server_process.pid), signal.SIGKILL) + + # REMOVE + # print(f"Killed server process with PID: {server_process.pid if server_process else 'N/A'}") + + wait_for_gpu_memory_to_clear() + + # Clean vLLM config + config_path = os.path.expanduser("~/.config/vllm") + if os.path.exists(config_path): + subprocess.run(["rm", "-rf", config_path]) + + # REMOVE + # print(f"Killed server on port {port} and cleaned up GPU processes.") \ No newline at end of file diff --git a/vllm/v1/spec_decode/dynamic_profiling.py b/vllm/v1/spec_decode/dynamic/process_benchmark_results.py similarity index 65% rename from vllm/v1/spec_decode/dynamic_profiling.py rename to vllm/v1/spec_decode/dynamic/process_benchmark_results.py index abd57915628d..34d688571993 100644 --- a/vllm/v1/spec_decode/dynamic_profiling.py +++ b/vllm/v1/spec_decode/dynamic/process_benchmark_results.py @@ -2,7 +2,18 @@ import os import json import argparse -from vllm.v1.spec_decode.online_profiling_client import (NGRAM_FMT, EAGLE_FMT) +from vllm.v1.spec_decode.dynamic.online_profiling_client import (NGRAM_FMT, EAGLE_FMT) + + +def reverse_fmt(fmt_str): + # e.g., convert 'min-{min}-max-{max}-k-{k}' -> 'min-{}-max-{}-k-{}' + FMT = re.sub(r"\{[^}]+\}", "{}", fmt_str) + # e.g., convert 'min-{}-max-{}-k-{}' -> 'min-(.+)-max-(.+)-k-(.+)' + FMT = FMT.replace("{}", "(.+)") + return FMT + +NGRAM_FMT_REVERSE = reverse_fmt(NGRAM_FMT) +EAGLE_FMT_REVERSE = reverse_fmt(EAGLE_FMT) def parse_itl(args): @@ -31,12 +42,12 @@ def parse_itl(args): # find sd params spec_config_str = log_file.split("_")[0] - if method == "ngram": - FMT = NGRAM_FMT.replace("{}", "(.+)") - min, max, k = re.match(FMT, spec_config_str).groups() + if method == "vanilla": + k=0 + elif method == "ngram": + 
min, max, k = re.match(NGRAM_FMT_REVERSE, spec_config_str).groups() elif method == "eagle": - FMT = EAGLE_FMT.replace("{}", "(.+)") - k = re.match(FMT, spec_config_str).groups()[0] + k = re.match(EAGLE_FMT_REVERSE, spec_config_str).groups()[0] # read the log file to get the itl with open(os.path.join(args.benchmark_path, log_file), "r") as f: @@ -51,7 +62,11 @@ def parse_itl(args): return batch_stats - +""" +python3 vllm/v1/spec_decode/process_benchmark_results.py \ + --sd-method eagle \ + --benchmark-path-parent 'log/dynamic_sd/tp-1_temp-0_top_p-1/philschmid/mt-bench/' +""" if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--sd-method", type=str, default=None) @@ -60,4 +75,5 @@ def parse_itl(args): args = parser.parse_args() assert args.sd_method in ["ngram", "eagle"], "Invalid method specified." - batch_stats = parse_itl(args) \ No newline at end of file + batch_stats = parse_itl(args) + print(json.dumps(batch_stats, indent=4)) \ No newline at end of file diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index e4270bd2fcd5..7bfa228c25ce 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -34,7 +34,7 @@ from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.ubatching import dbo_current_ubatch_id -from vllm.vllm.model_executor.models.interfaces import SpeculativeDecodingProposer +from vllm.model_executor.models.interfaces import SpeculativeDecodingProposer logger = init_logger(__name__) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index cfe9070b90cc..b8c796ffa7c6 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -7,7 +7,7 @@ from typing import Optional from vllm.config import VllmConfig -from vllm.vllm.model_executor.models.interfaces import SpeculativeDecodingProposer +from vllm.model_executor.models.interfaces import SpeculativeDecodingProposer class NgramProposer(SpeculativeDecodingProposer): diff --git a/vllm/v1/spec_decode/online_profiling_client.py b/vllm/v1/spec_decode/online_profiling_client.py deleted file mode 100644 index be4e93e39f85..000000000000 --- a/vllm/v1/spec_decode/online_profiling_client.py +++ /dev/null @@ -1,185 +0,0 @@ -import os -import subprocess -import pandas as pd -from dataclasses import dataclass -import argparse -from vllm.vllm.v1.spec_decode.online_profiling_server import ( - start_server, - kill_server, - setup_server) - - -@dataclass -class Dataset: - name: str - config: list - -NGRAM_FMT = "min-{min}-max-{max}-k-{k}" -EAGLE_FMT = "k-{k}" - -def run_command(command): - try: - result = subprocess.run(f"bash -c '{command}'", - shell=True, - check=True, - capture_output=True, - text=True) - print("Output:") - print(result.stdout) - except subprocess.CalledProcessError as e: - print("Error:") - print(e.stderr) - - -def run_benchmarks(args): - - # setup server - setup_server() - - port=9001 - all_sampling_profile=[ - {'temperature': 0, 'topp': 1}, # greedy - ] - - MTBENCH_CONFIG = [{"num_samples_per_seq": 20}] - - all_bench_dataset = [ - Dataset(name = "philschmid/mt-bench", config = MTBENCH_CONFIG), - ] - - assert (all(len(ds.config) > 0 for ds in all_bench_dataset)), "Each dataset must have at least one config" - - all_ngram_params = [{"min": 2, "max": 5, "k": k} for k in args.num_speculative_tokens] - all_eagle_params = args.num_speculative_tokens - - # ablation - num_exp_run = 0 - for tp in args.tp: - for 
spec_method in args.method_list: - # collate all spec configs to run for a given method - all_spec_config = [] - if spec_method == "ngram": - for ngram_params in all_ngram_params: - all_spec_config.append({ - "method": "ngram", - "num_speculative_tokens": ngram_params['k'], - "prompt_lookup_max": ngram_params['max'], - "prompt_lookup_min": ngram_params['min'], - }) - elif spec_method == "eagle": - for eagle_k in all_eagle_params: - all_spec_config.append({ - "method": "eagle", - "model": args.draft_dir, - "num_speculative_tokens": eagle_k, - "draft_tensor_parallel_size": tp, - }) - else: - # vanilla case - all_spec_config.append(None) - - for spec_config in all_spec_config: - # start server - server_process = start_server(port=port, - target_model_dir=args.model_dir, - spec_config=spec_config, - tp=tp, - max_vllm_bs=args.max_vllm_batch_size, - dry_run=args.dry_run) - - # start client - for bench_concurrency in args.batch_size_list: - for bench_dataset_object in all_bench_dataset: - bench_dataset = bench_dataset_object.name - for bench_config in bench_dataset_object.config: - for sampling_profile in all_sampling_profile: - bench_temperature = sampling_profile['temperature'] - bench_topp = sampling_profile['topp'] - - spec_config_str = "vanilla" - if spec_method == "ngram": - spec_config_str = NGRAM_FMT.format( - min=spec_config['prompt_lookup_min'], - max=spec_config['prompt_lookup_max'], - k=spec_config['num_speculative_tokens'] - ) - elif spec_method == "eagle": - spec_config_str = EAGLE_FMT.format( - k=spec_config['num_speculative_tokens'] - ) - - # dataset specific config - if "philschmid/mt-bench" in bench_dataset: - bench_config_str = f"mt_bench" - num_prompts = bench_config["num_prompts"] * bench_concurrency - bench_vllm_serve_config = f'--dataset-name hf --dataset-path {bench_dataset} --num-prompts {num_prompts}' - - print(f"Number of prompts in {bench_dataset}: {num_prompts}") - - # create dir if not exists - result_dir = f"{args.result_dir}/tp-{tp}_temp-{bench_temperature}_top_p-{bench_topp}/{bench_dataset}/{spec_method}/online/" - if not os.path.exists(result_dir): - os.makedirs(result_dir) - - cmd = f'''time vllm bench serve --port {port} --save-result --save-detailed \ - --model {args.model_dir} \ - --backend openai-chat \ - --endpoint /v1/chat/completions \ - {bench_vllm_serve_config} \ - --max-concurrency {bench_concurrency} \ - --temperature={bench_temperature} \ - --top-p={bench_topp} \ - --result-dir "{result_dir}" \ - --result-filename "{spec_config_str}_{bench_config_str}_bs-{bench_concurrency}_{args.extra_log_arg}.txt"''' - - print(cmd) - num_exp_run += 1 - - if not args.dry_run: - run_command(cmd) - - # server teardown: kill server and any gpu processes - kill_server(port, server_process) - - print(f"Total number of experiments run: {num_exp_run}") - - -# time python3 vllm/cohere/utils/eagle/sweep_ngram_eagle_online_benchmark.py -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--dry-run", action="store_true", help="Run in dry run mode. 
If set, commands will be printed but not executed.") - parser.add_argument("--model-dir", type=str, default=None) - parser.add_argument("--draft-dir", type=str, default=None) - parser.add_argument("--method-list", type=list[str], default=["vanilla", "eagle"]) - parser.add_argument("--num-speculative-tokens-list", type=list[int], default=[1, 3, 5]) - parser.add_argument("--batch-size-list", type=list[int], default=[1, 4, 8, 16, 32, 64, 128]) - parser.add_argument("--max-vllm-batch-size", type=int, help="Max vllm server batch size (max concurrency)") - parser.add_argument("--tp-list", type=list[int], default=[1]) - parser.add_argument("--result-dir", type=str, default="./log") - parser.add_argument("--extra-log-arg", type=str, default="") - args = parser.parse_args() - - assert all([method in ["vanilla", "ngram", "eagle", "eagle3"] for method in args.method_list]), \ - "invalid method in method_list" - - assert 1 in args.batch_size_list, "batch_size must contain 1" - assert 1 in args.num_speculative_tokens_list, "num_speculative_tokens must contain 1" - assert args.max_vllm_batch_size == max(args.batch_size_list), \ - "max_vllm_batch_size must be equal to max of batch_size" - - model_dir = args.model_dir - args.model_dir = "meta-llama/Llama-3.1-8B-Instruct" if args.model_dir is None else args.model_dir - - if args.method == "eagle" or args.method == "eagle3": - if args.method == "eagle" and args.eagle_dir is None: - args.eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" - - elif args.method == "eagle3" and args.eagle_dir is None: - args.eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" - speculative_config = { - "method": args.method, - "model": args.eagle_dir, - "num_speculative_tokens": args.num_spec_tokens, - } - - run_benchmarks(args) \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bac1d5b7177c..a9f4c700ad02 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -299,10 +299,13 @@ def __init__( self.rejection_sampler = RejectionSampler() # setup Dynamic Speculative Decoding - self.dynamic_sd_manager = DynamicSpeculativeDecodingManager( - self.speculative_config.dynamic_config, - self.vllm_config.scheduler_config.max_num_seqs, - self.vllm_config.speculative_config.num_speculative_tokens) + if self.speculative_config.dynamic_config: + self.dynamic_sd_manager = DynamicSpeculativeDecodingManager( + self.speculative_config.dynamic_config, + self.vllm_config.scheduler_config.max_num_seqs, + self.vllm_config.speculative_config.num_speculative_tokens) + else: + self.dynamic_sd_manager = None # Request states. 
self.requests: dict[str, CachedRequestState] = {} From 06916e5a9e335fb5c06de3ebd78d2bf709958af3 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Sun, 4 Jan 2026 04:33:21 +0000 Subject: [PATCH 03/39] pre Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/config/speculative.py | 8 +- vllm/v1/spec_decode/dynamic/dynamic.py | 175 ++++++++++++------ .../dynamic/online_profiling_client.py | 140 ++++++++------ .../dynamic/online_profiling_server.py | 70 ++++--- .../dynamic/process_benchmark_results.py | 43 +++-- vllm/v1/spec_decode/ngram_proposer.py | 6 +- vllm/v1/worker/gpu_model_runner.py | 23 ++- 7 files changed, 298 insertions(+), 167 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 7f816215eaf9..79f7e1435b6f 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -48,14 +48,16 @@ EagleModelTypes, ] + @dataclass class DynamicSpeculativeConfig: # """A mapping from batch size to optimal number of drafts to use for that # batch size. This is used to dynamically adjust the number of drafts used # based on the current batch size.""" # optimal_num_speculative_tokens: dict[int, int] = None - + """Whether the statistics are updated online or not during inference.""" + is_online: bool = False """ @@ -82,7 +84,7 @@ class DynamicSpeculativeConfig: """Acceptance rate per position on an offline dataset.""" acceptance_rate_per_pos: list[float] = None - + @config @dataclass @@ -156,7 +158,7 @@ class SpeculativeConfig: # dynamic speculative decoding control """Configuration for dynamic speculative decoding, if provided.""" - dynamic_config: Optional[DynamicSpeculativeConfig] = None + dynamic_config: DynamicSpeculativeConfig | None = None # params generated in the post-init stage draft_model_config: SkipValidation[ModelConfig] = None # type: ignore diff --git a/vllm/v1/spec_decode/dynamic/dynamic.py b/vllm/v1/spec_decode/dynamic/dynamic.py index a9caab80eda8..91c3b4223f9e 100644 --- a/vllm/v1/spec_decode/dynamic/dynamic.py +++ b/vllm/v1/spec_decode/dynamic/dynamic.py @@ -1,10 +1,10 @@ -from typing import Any, Tuple, Optional, List +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.config.speculative import DynamicSpeculativeConfig - # _DYNAMIC_STATS = { # "max_num_speculative_tokens": 7, -# "acceptance_rate_per_pos": [0.68, 0.39, 0.20, 0.10, 0.06, 0.03, 0.02], # E1 +# "acceptance_rate_per_pos": [0.68, 0.39, 0.20, 0.10, 0.06, 0.03, 0.02], # E1 # # "acceptance_rate_per_pos": [0.76, 0.54, 0.39, 0.28, 0.21, 0.15, 0.12], # E3 # # "acceptance_rate_per_pos": [0.17, 0.15, 0.12, 0.01, 0.01, 0.01, 0.01], # E1 - low # "batch_stats": { @@ -18,91 +18,154 @@ _DYNAMIC_STATS = DynamicSpeculativeConfig( max_num_speculative_tokens=7, - acceptance_rate_per_pos=[0.68, 0.39, 0.20, 0.10, 0.06, 0.03, 0.02], # E1 + acceptance_rate_per_pos=[0.68, 0.39, 0.20, 0.10, 0.06, 0.03, 0.02], # E1 # acceptance_rate_per_pos=[0.76, 0.54, 0.39, 0.28, 0.21, 0.15, 0.12], # E3 # acceptance_rate_per_pos=[0.17, 0.15, 0.12, 0.01, 0.01, 0.01, 0.01], # E1 - low batch_stats={ - 1: {0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29,}, - 4: {0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29,}, - 16: {0: 7.3, 1: 8.39, 3: 9.95, 4: 10.8, 5: 11.59, 7: 13.11,}, - 32: {0: 7.64, 1: 8.97, 3: 10.78, 4: 11.79, 5: 12.81, 7: 14.86,}, - 64: {0: 8.53, 1: 10.44, 3: 13.16, 4: 15.7, 5: 17.54, 7: 120.57,}, - 128: {0: 8.53, 1: 15.44, 3: 25.16, 4: 30.7, 5: 
37.54, 7: 220.57,}, # fake - } + 1: { + 0: 6.87, + 1: 7.97, + 3: 9.41, + 4: 9.91, + 5: 10.8, + 7: 12.29, + }, + 4: { + 0: 6.87, + 1: 7.97, + 3: 9.41, + 4: 9.91, + 5: 10.8, + 7: 12.29, + }, + 16: { + 0: 7.3, + 1: 8.39, + 3: 9.95, + 4: 10.8, + 5: 11.59, + 7: 13.11, + }, + 32: { + 0: 7.64, + 1: 8.97, + 3: 10.78, + 4: 11.79, + 5: 12.81, + 7: 14.86, + }, + 64: { + 0: 8.53, + 1: 10.44, + 3: 13.16, + 4: 15.7, + 5: 17.54, + 7: 120.57, + }, + 128: { + 0: 8.53, + 1: 15.44, + 3: 25.16, + 4: 30.7, + 5: 37.54, + 7: 220.57, + }, # fake + }, ) -class DynamicSpeculativeDecodingManager: +class DynamicSpeculativeDecodingManager: """A mapping from batch size to optimal number of drafts to use for that batch size. This is used to dynamically adjust the number of drafts used based on the current batch size.""" + _optimal_num_speculative_tokens: dict[int, int] - def __init__(self, - dynamic_config: Optional[DynamicSpeculativeConfig], - vllm_max_batch_size: int, - vllm_num_speculative_tokens: int): + def __init__( + self, + dynamic_config: DynamicSpeculativeConfig | None, + vllm_max_batch_size: int, + vllm_num_speculative_tokens: int, + ): self.dynamic_config = dynamic_config self.vllm_max_batch_size = vllm_max_batch_size self.batch_stats = self.dynamic_config.batch_stats self.available_batch_sizes = sorted(self.dynamic_config.batch_stats.keys()) # Sanity check - assert vllm_num_speculative_tokens <= self.dynamic_config.max_num_speculative_tokens, \ - "vllm_num_speculative_tokens must be <= max_num_speculative_tokens" - + assert ( + vllm_num_speculative_tokens + <= self.dynamic_config.max_num_speculative_tokens + ), "vllm_num_speculative_tokens must be <= max_num_speculative_tokens" + # if self.dynamic_config.is_online: - assert self.dynamic_config.max_num_speculative_tokens == len(self.dynamic_config.acceptance_rate_per_pos), \ + assert self.dynamic_config.max_num_speculative_tokens == len( + self.dynamic_config.acceptance_rate_per_pos + ), ( "max_num_speculative_tokens must be equal to the length of acceptance_rate_per_pos" - assert self.dynamic_config.max_num_speculative_tokens > 0, "max_num_speculative_tokens must be > 0" - assert all(0 < a < 1 for a in self.dynamic_config.acceptance_rate_per_pos), "all acceptance_rate_per_pos values must be in (0, 1)" - assert 1 in self.dynamic_config.batch_stats, f"batch size 1 must be available, found: {self.dynamic_config.batch_stats.keys()}" - assert vllm_max_batch_size in self.dynamic_config.batch_stats, \ + ) + assert self.dynamic_config.max_num_speculative_tokens > 0, ( + "max_num_speculative_tokens must be > 0" + ) + assert all(0 < a < 1 for a in self.dynamic_config.acceptance_rate_per_pos), ( + "all acceptance_rate_per_pos values must be in (0, 1)" + ) + assert 1 in self.dynamic_config.batch_stats, ( + f"batch size 1 must be available, found: {self.dynamic_config.batch_stats.keys()}" + ) + assert vllm_max_batch_size in self.dynamic_config.batch_stats, ( f"vllm max_num_seqs {vllm_max_batch_size} must be available, found: {self.dynamic_config.batch_stats.keys()}" + ) for bs in self.available_batch_sizes: assert bs > 0 - assert 0 in self.dynamic_config.batch_stats[bs], \ + assert 0 in self.dynamic_config.batch_stats[bs], ( f"batch size {bs} must have draft 0 stats" - assert 1 in self.dynamic_config.batch_stats[bs], \ + ) + assert 1 in self.dynamic_config.batch_stats[bs], ( f"batch size {bs} must have draft 1 stats" - assert sorted(self.dynamic_config.batch_stats[bs].keys()) == \ - list(self.dynamic_config.batch_stats[bs].keys()), \ - f"batch size {bs} draft keys must 
be sorted" + ) + assert sorted(self.dynamic_config.batch_stats[bs].keys()) == list( + self.dynamic_config.batch_stats[bs].keys() + ), f"batch size {bs} draft keys must be sorted" self.update_optimal_num_speculative_tokens() - - def get_optimal_num_speculative_tokens(self, batch_size: int) -> int: assert batch_size > 0, "batch_size must be > 0" - assert batch_size <= self.vllm_max_batch_size, \ + assert batch_size <= self.vllm_max_batch_size, ( "batch_size must be <= vllm_max_batch_size" + ) return self._optimal_num_speculative_tokens[batch_size] - def update_optimal_num_speculative_tokens(self): self._optimal_num_speculative_tokens = { - bs: self._compute_optimal_num_speculative_tokens(bs) \ - for bs in range(1, self.vllm_max_batch_size + 1) - } - + bs: self._compute_optimal_num_speculative_tokens(bs) + for bs in range(1, self.vllm_max_batch_size + 1) + } def _get_batch_stats(self, batch_size: int) -> dict: # import pdb; pdb.set_trace() if batch_size not in self.batch_stats: # find the nearest batch size smaller and bigger than the given batch size # and return the weighted avg of their stats - print(f"Finding batch stats for batch_size: {batch_size} in self.available_batch_sizes: {self.available_batch_sizes}") - + print( + f"Finding batch stats for batch_size: {batch_size} in self.available_batch_sizes: {self.available_batch_sizes}" + ) + smaller_bs = [bs for bs in self.available_batch_sizes if bs < batch_size] - smaller_bs = max(smaller_bs) if len(smaller_bs) else self.available_batch_sizes[0] + smaller_bs = ( + max(smaller_bs) if len(smaller_bs) else self.available_batch_sizes[0] + ) larger_bs = [bs for bs in self.available_batch_sizes if bs > batch_size] - larger_bs = min(larger_bs) if len(larger_bs) else self.available_batch_sizes[-1] - + larger_bs = ( + min(larger_bs) if len(larger_bs) else self.available_batch_sizes[-1] + ) # REMOVE - print(f"smaller_bs: {smaller_bs}, larger_bs: {larger_bs}, batch_size: {batch_size}") + print( + f"smaller_bs: {smaller_bs}, larger_bs: {larger_bs}, batch_size: {batch_size}" + ) smaller_bs_stat = self.batch_stats[smaller_bs] larger_bs_stat = self.batch_stats[larger_bs] @@ -113,30 +176,30 @@ def _get_batch_stats(self, batch_size: int) -> dict: print(f"ratio: {ratio}") avg_stat: dict[int, float] = {} - for k in smaller_bs_stat.keys(): - avg_stat[k] = smaller_bs_stat[k] + ratio * (larger_bs_stat[k] - smaller_bs_stat[k]) - + for k in smaller_bs_stat: + avg_stat[k] = smaller_bs_stat[k] + ratio * ( + larger_bs_stat[k] - smaller_bs_stat[k] + ) + return avg_stat else: return self.batch_stats[batch_size] - def _get_itl(self, batch_stats, num_drafts: int) -> float: if num_drafts in batch_stats: return batch_stats[num_drafts] else: lower_num_draft = max(k for k in batch_stats.keys() if k < num_drafts) upper_num_draft = min(k for k in batch_stats.keys() if k > num_drafts) - + # REMOVE # print(f"lower_num_draft: {lower_num_draft}, upper_num_draft: {upper_num_draft}, num_drafts: {num_drafts}") - + ratio = (num_drafts - lower_num_draft) / (upper_num_draft - lower_num_draft) lower_itl = batch_stats[lower_num_draft] upper_itl = batch_stats[upper_num_draft] return lower_itl + ratio * (upper_itl - lower_itl) - def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: batch_stats = self._get_batch_stats(batch_size) @@ -148,9 +211,11 @@ def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: if curr_goodput > max_goodput: max_goodput = curr_goodput chosen_num_drafts = num_drafts - + # REMOVE - print(f"num_drafts: {num_drafts}, al: 
{curr_al}, itl: {curr_itl}, goodput: {curr_goodput}") + print( + f"num_drafts: {num_drafts}, al: {curr_al}, itl: {curr_itl}, goodput: {curr_goodput}" + ) return chosen_num_drafts @@ -163,6 +228,8 @@ def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: vllm_max_batch_size=MAX_TEST_BS, vllm_num_speculative_tokens=7, ) - for i in range(4, MAX_TEST_BS+1, 4): + for i in range(4, MAX_TEST_BS + 1, 4): print("\n====================================") - print(f"bs: {i}, optimal num drafts: {dynamic_sd.get_optimal_num_speculative_tokens(i)}") \ No newline at end of file + print( + f"bs: {i}, optimal num drafts: {dynamic_sd.get_optimal_num_speculative_tokens(i)}" + ) diff --git a/vllm/v1/spec_decode/dynamic/online_profiling_client.py b/vllm/v1/spec_decode/dynamic/online_profiling_client.py index 3e9baae37bdc..4a88723de797 100644 --- a/vllm/v1/spec_decode/dynamic/online_profiling_client.py +++ b/vllm/v1/spec_decode/dynamic/online_profiling_client.py @@ -1,12 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse import os import subprocess -import pandas as pd from dataclasses import dataclass -import argparse + from vllm.v1.spec_decode.dynamic.online_profiling_server import ( + kill_server, + setup_server, start_server, - kill_server, - setup_server) +) @dataclass @@ -14,16 +17,20 @@ class Dataset: name: str config: list + NGRAM_FMT = "min-{min}-max-{max}-k-{k}" EAGLE_FMT = "k-{k}" + def run_command(command): try: - result = subprocess.run(f"bash -c '{command}'", - shell=True, - check=True, - capture_output=True, - text=True) + result = subprocess.run( + f"bash -c '{command}'", + shell=True, + check=True, + capture_output=True, + text=True, + ) print("Output:") print(result.stdout) except subprocess.CalledProcessError as e: @@ -32,13 +39,12 @@ def run_command(command): def run_benchmarks(args): - # setup server setup_server() - port=9001 - all_sampling_profile=[ - {'temperature': 0, 'topp': 1}, # greedy + port = 9001 + all_sampling_profile = [ + {"temperature": 0, "topp": 1}, # greedy ] # `num_batches` decides how many batches are sent for each concurrency. 
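The `num_batches` knob fixes the number of request waves per concurrency level, so the total prompt count grows with concurrency instead of staying constant. A minimal illustration of the arithmetic (values taken from MTBENCH_CONFIG and the default batch-size list; the loop itself is only for illustration and mirrors the num_prompts computation used further down in run_benchmarks):

    num_batches = 20  # MTBENCH_CONFIG value
    for concurrency in [1, 4, 16, 64, 256]:
        # same number of batches at every concurrency level
        num_prompts = num_batches * concurrency
        print(f"concurrency={concurrency}: {num_prompts} prompts "
              f"= {num_batches} batches of {concurrency}")

With num_batches=20, concurrency 4 sends 80 prompts (20 batches of 4) while concurrency 64 sends 1280, which keeps the per-configuration ITL measurements comparable across concurrencies.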
@@ -50,12 +56,16 @@ def run_benchmarks(args): MTBENCH_CONFIG = [{"num_batches": 20}] all_bench_dataset = [ - Dataset(name = "philschmid/mt-bench", config = MTBENCH_CONFIG), + Dataset(name="philschmid/mt-bench", config=MTBENCH_CONFIG), ] - assert (all(len(ds.config) > 0 for ds in all_bench_dataset)), "Each dataset must have at least one config" + assert all(len(ds.config) > 0 for ds in all_bench_dataset), ( + "Each dataset must have at least one config" + ) - all_ngram_params = [{"min": 2, "max": 5, "k": k} for k in args.num_speculative_tokens_list] + all_ngram_params = [ + {"min": 2, "max": 5, "k": k} for k in args.num_speculative_tokens_list + ] all_eagle_params = args.num_speculative_tokens_list # ablation @@ -66,32 +76,38 @@ def run_benchmarks(args): all_spec_config = [] if spec_method == "ngram": for ngram_params in all_ngram_params: - all_spec_config.append({ - "method": "ngram", - "num_speculative_tokens": ngram_params['k'], - "prompt_lookup_max": ngram_params['max'], - "prompt_lookup_min": ngram_params['min'], - }) + all_spec_config.append( + { + "method": "ngram", + "num_speculative_tokens": ngram_params["k"], + "prompt_lookup_max": ngram_params["max"], + "prompt_lookup_min": ngram_params["min"], + } + ) elif spec_method == "eagle": for eagle_k in all_eagle_params: - all_spec_config.append({ - "method": "eagle", - "model": args.draft_dir, - "num_speculative_tokens": eagle_k, - "draft_tensor_parallel_size": tp, - }) + all_spec_config.append( + { + "method": "eagle", + "model": args.draft_dir, + "num_speculative_tokens": eagle_k, + "draft_tensor_parallel_size": tp, + } + ) else: # vanilla case all_spec_config.append(None) for spec_config in all_spec_config: # start server - server_process = start_server(port=port, - target_model_dir=args.model_dir, - spec_config=spec_config, - tp=tp, - max_vllm_bs=args.max_vllm_batch_size, - dry_run=args.dry_run) + server_process = start_server( + port=port, + target_model_dir=args.model_dir, + spec_config=spec_config, + tp=tp, + max_vllm_bs=args.max_vllm_batch_size, + dry_run=args.dry_run, + ) # start client for bench_concurrency in args.batch_size_list: @@ -99,31 +115,35 @@ def run_benchmarks(args): bench_dataset = bench_dataset_object.name for bench_config in bench_dataset_object.config: for sampling_profile in all_sampling_profile: - bench_temperature = sampling_profile['temperature'] - bench_topp = sampling_profile['topp'] + bench_temperature = sampling_profile["temperature"] + bench_topp = sampling_profile["topp"] spec_config_str = "vanilla" if spec_method == "ngram": spec_config_str = NGRAM_FMT.format( - min=spec_config['prompt_lookup_min'], - max=spec_config['prompt_lookup_max'], - k=spec_config['num_speculative_tokens'] + min=spec_config["prompt_lookup_min"], + max=spec_config["prompt_lookup_max"], + k=spec_config["num_speculative_tokens"], ) elif spec_method == "eagle": spec_config_str = EAGLE_FMT.format( - k=spec_config['num_speculative_tokens'] + k=spec_config["num_speculative_tokens"] ) # dataset specific config if "philschmid/mt-bench" in bench_dataset: - bench_config_str = f"mt_bench" - num_prompts = bench_config["num_batches"] * bench_concurrency - bench_vllm_serve_config = f'--dataset-name hf --dataset-path {bench_dataset} --num-prompts {num_prompts}' + bench_config_str = "mt_bench" + num_prompts = ( + bench_config["num_batches"] * bench_concurrency + ) + bench_vllm_serve_config = f"--dataset-name hf --dataset-path {bench_dataset} --num-prompts {num_prompts}" # noqa E501 - print(f"Number of prompts in {bench_dataset}: {num_prompts}") 
+ print( + f"Number of prompts in {bench_dataset}: {num_prompts}" + ) # create dir if not exists - result_dir = f"{args.result_dir}/tp-{tp}_temp-{bench_temperature}_top_p-{bench_topp}/{bench_dataset}/{spec_method}/" + result_dir = f"{args.result_dir}/tp-{tp}_temp-{bench_temperature}_top_p-{bench_topp}/{bench_dataset}/{spec_method}/" # noqa E501 if not os.path.exists(result_dir): os.makedirs(result_dir) @@ -168,15 +188,27 @@ def run_benchmarks(args): """ if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--dry-run", action="store_true", help="Run in dry run mode. If set, commands will be printed but not executed.") + parser.add_argument( + "--dry-run", + action="store_true", + help="Run in dry run mode. If set, commands will be printed but not executed.", + ) parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--draft-dir", type=str, default=None) # parser.add_argument("--method-list", nargs='*', type=str, default=["vanilla", "eagle"]) parser.add_argument("--method", type=str, default="vanilla") - parser.add_argument("--num-speculative-tokens-list", nargs='*', type=int, default=[1, 3, 5]) - parser.add_argument("--batch-size-list", nargs='*', type=int, default=[1, 4, 16, 64, 256]) - parser.add_argument("--max-vllm-batch-size", type=int, help="Max vllm server batch size (max concurrency)") - parser.add_argument("--tp-list", nargs='*', type=int, default=[1]) + parser.add_argument( + "--num-speculative-tokens-list", nargs="*", type=int, default=[1, 3, 5] + ) + parser.add_argument( + "--batch-size-list", nargs="*", type=int, default=[1, 4, 16, 64, 256] + ) + parser.add_argument( + "--max-vllm-batch-size", + type=int, + help="Max vllm server batch size (max concurrency)", + ) + parser.add_argument("--tp-list", nargs="*", type=int, default=[1]) parser.add_argument("--result-dir", type=str, default="./log/dynamic_sd") parser.add_argument("--extra-log-arg", type=str, default="") args = parser.parse_args() @@ -186,12 +218,14 @@ def run_benchmarks(args): # assert 0 < len(args.method_list) <= 2 # if len(args.method_list) == 2: # assert "vanilla" in args.method_list, "If two methods are specified, one must be vanilla" - assert args.method in ["vanilla", "ngram", "eagle", "eagle3"], \ + assert args.method in ["vanilla", "ngram", "eagle", "eagle3"], ( "invalid method specified" + ) # assert 1 in args.batch_size_list, "batch_size must contain 1" # assert 1 in args.num_speculative_tokens_list, "num_speculative_tokens must contain 1" - assert args.max_vllm_batch_size == max(args.batch_size_list), \ + assert args.max_vllm_batch_size == max(args.batch_size_list), ( "max_vllm_batch_size must be equal to max of batch_size" + ) - run_benchmarks(args) \ No newline at end of file + run_benchmarks(args) diff --git a/vllm/v1/spec_decode/dynamic/online_profiling_server.py b/vllm/v1/spec_decode/dynamic/online_profiling_server.py index b1a1fff6a473..3d82da653fbc 100644 --- a/vllm/v1/spec_decode/dynamic/online_profiling_server.py +++ b/vllm/v1/spec_decode/dynamic/online_profiling_server.py @@ -1,24 +1,29 @@ -import os -import subprocess +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json -import time +import os import shutil import signal +import subprocess +import time """ Utility functions to manage the vLLM server for online profiling. Main functions are setup_server(), start_server(), and kill_server(). 
""" + def wait_for_server(port: int) -> bool: - timeout = 1200 # 20 mins + timeout = 1200 # 20 mins start_time = time.time() while time.time() - start_time < timeout: try: - subprocess.run(["curl", "-X", "POST", f"localhost:{port}/v1/completions"], check=True) + subprocess.run( + ["curl", "-X", "POST", f"localhost:{port}/v1/completions"], check=True + ) return True except subprocess.CalledProcessError: - time.sleep(10) # wait for 10 seconds before retrying + time.sleep(10) # wait for 10 seconds before retrying return False @@ -41,7 +46,7 @@ def kill_gpu_processes(port: int): if "python3" in line and filename not in line: pid = line.split()[1] pids_to_kill.append(pid) - + # Kill other processes for pid in pids_to_kill: subprocess.run(["kill", "-9", pid]) @@ -50,15 +55,20 @@ def kill_gpu_processes(port: int): # subprocess.run(["rm", "-rf", "~/.config/vllm"]) + def wait_for_gpu_memory_to_clear(): # Wait until all GPUs have memory usage < 1000 MB if shutil.which("nvidia-smi"): while True: # Get GPU memory usage for all GPUs - memory_usage = subprocess.check_output(["nvidia-smi", - "--query-gpu=memory.used", - "--format=csv,noheader,nounits"], - text=True) + memory_usage = subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=memory.used", + "--format=csv,noheader,nounits", + ], + text=True, + ) # Split the output into individual GPU memory usage values gpu_memory_usage = [int(x) for x in memory_usage.strip().split("\n")] # Check if any GPU has memory usage >= 1000 MB @@ -67,16 +77,15 @@ def wait_for_gpu_memory_to_clear(): time.sleep(1) elif shutil.which("amd-smi"): while True: - memory_usage = subprocess.check_output(["amd-smi", - "metric", - "-g", - "0"], - text=True) + memory_usage = subprocess.check_output( + ["amd-smi", "metric", "-g", "0"], text=True + ) used_vram = int(memory_usage.split("USED_VRAM")[1].split()[0]) if used_vram < 1000: break time.sleep(1) + def setup_server(): # install dependencies dependencies = ["lsof", "curl", "pgrep"] @@ -86,13 +95,14 @@ def setup_server(): subprocess.run(["apt-get", "install", "-y", dep]) -def start_server(port: int, - target_model_dir: str, - spec_config: dict | None, - tp: int, - max_vllm_bs: int, - dry_run: bool = False) -> subprocess.Popen | None: - +def start_server( + port: int, + target_model_dir: str, + spec_config: dict | None, + tp: int, + max_vllm_bs: int, + dry_run: bool = False, +) -> subprocess.Popen | None: # NOTE: no Prompt Caching, but enabled chunked prefill server_command = f"""VLLM_USE_V1=1 vllm serve {target_model_dir} \ --disable-log-requests --port {port} \ @@ -104,13 +114,17 @@ def start_server(port: int, if spec_config: speculative_config_json_serialized = json.dumps(spec_config).replace('"', '\\"') - server_command += f'--speculative_config "{speculative_config_json_serialized}" ' - + server_command += ( + f'--speculative_config "{speculative_config_json_serialized}" ' + ) + print(f"Server command: {server_command}") # start vllm server if not dry_run: - server_process = subprocess.Popen(server_command, shell=True, preexec_fn=os.setsid) + server_process = subprocess.Popen( + server_command, shell=True, preexec_fn=os.setsid + ) if wait_for_server(port): print("vllm server is up and running.") @@ -128,8 +142,8 @@ def start_server(port: int, # server_process.kill() # kill_gpu_processes(port) -def kill_server(port, server_process): +def kill_server(port, server_process): # REMOVE # print(f"Killing server on port {port}...") @@ -147,4 +161,4 @@ def kill_server(port, server_process): subprocess.run(["rm", "-rf", 
config_path]) # REMOVE - # print(f"Killed server on port {port} and cleaned up GPU processes.") \ No newline at end of file + # print(f"Killed server on port {port} and cleaned up GPU processes.") diff --git a/vllm/v1/spec_decode/dynamic/process_benchmark_results.py b/vllm/v1/spec_decode/dynamic/process_benchmark_results.py index 34d688571993..f0d2a11b685b 100644 --- a/vllm/v1/spec_decode/dynamic/process_benchmark_results.py +++ b/vllm/v1/spec_decode/dynamic/process_benchmark_results.py @@ -1,8 +1,11 @@ -import re -import os -import json +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse -from vllm.v1.spec_decode.dynamic.online_profiling_client import (NGRAM_FMT, EAGLE_FMT) +import json +import os +import re + +from vllm.v1.spec_decode.dynamic.online_profiling_client import EAGLE_FMT, NGRAM_FMT def reverse_fmt(fmt_str): @@ -12,6 +15,7 @@ def reverse_fmt(fmt_str): FMT = FMT.replace("{}", "(.+)") return FMT + NGRAM_FMT_REVERSE = reverse_fmt(NGRAM_FMT) EAGLE_FMT_REVERSE = reverse_fmt(EAGLE_FMT) @@ -31,37 +35,41 @@ def parse_itl(args): for method in ["vanilla", args.sd_method]: # find the names of all log files in this folder args.benchmark_path = os.path.join(args.benchmark_path_parent, method) - all_log_files = [f for f in os.listdir(args.benchmark_path) \ - if os.path.isfile(os.path.join(args.benchmark_path, f)) \ - and f.endswith(".txt")] - + all_log_files = [ + f + for f in os.listdir(args.benchmark_path) + if os.path.isfile(os.path.join(args.benchmark_path, f)) + and f.endswith(".txt") + ] + # parse the log files to get the config params for log_file in all_log_files: # find bs - bs = re.search(r'_bs-(\d+)', log_file).group(1) - + bs = re.search(r"_bs-(\d+)", log_file).group(1) + # find sd params spec_config_str = log_file.split("_")[0] if method == "vanilla": - k=0 + k = 0 elif method == "ngram": min, max, k = re.match(NGRAM_FMT_REVERSE, spec_config_str).groups() elif method == "eagle": k = re.match(EAGLE_FMT_REVERSE, spec_config_str).groups()[0] # read the log file to get the itl - with open(os.path.join(args.benchmark_path, log_file), "r") as f: + with open(os.path.join(args.benchmark_path, log_file)) as f: data = json.load(f) itl = data["median_itl_ms"] # add to batch_stats if int(bs) not in batch_stats: batch_stats[int(bs)] = {} - + batch_stats[int(bs)][int(k)] = itl return batch_stats + """ python3 vllm/v1/spec_decode/process_benchmark_results.py \ --sd-method eagle \ @@ -70,10 +78,15 @@ def parse_itl(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--sd-method", type=str, default=None) - parser.add_argument("--benchmark-path-parent", type=str, default=None, help="Root folder which has the log files") + parser.add_argument( + "--benchmark-path-parent", + type=str, + default=None, + help="Root folder which has the log files", + ) args = parser.parse_args() assert args.sd_method in ["ngram", "eagle"], "Invalid method specified." 
batch_stats = parse_itl(args) - print(json.dumps(batch_stats, indent=4)) \ No newline at end of file + print(json.dumps(batch_stats, indent=4)) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 4c478b6d9c64..0969f8b70550 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -5,9 +5,7 @@ import numpy as np from numba import get_num_threads, jit, njit, prange, set_num_threads -from typing import Optional from vllm.config import VllmConfig -from vllm.model_executor.models.interfaces import SpeculativeDecodingProposer class NgramProposer: @@ -23,7 +21,7 @@ def __init__(self, vllm_config: VllmConfig): # Number of tokens follow the match. If there are less than k # tokens follow the match, we will return the maximum amount of # tokens until the end. - self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens + self.k = vllm_config.speculative_config.num_speculative_tokens # Maximum length of the model. self.max_model_len = vllm_config.model_config.max_model_len @@ -133,7 +131,7 @@ def batch_propose( def propose( self, - optimal_num_speculative_tokens: Optional[int], + optimal_num_speculative_tokens: int | None, sampled_token_ids: list[list[int]], req_ids: list[str], num_tokens_no_spec: np.ndarray, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 35b881a8090d..7774a8e3612a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -145,11 +145,11 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler +from vllm.v1.spec_decode.dynamic import DynamicSpeculativeDecodingManager from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer -from vllm.v1.spec_decode.dynamic import DynamicSpeculativeDecodingManager from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext @@ -460,9 +460,10 @@ def __init__( # setup Dynamic Speculative Decoding if self.speculative_config.dynamic_config: self.dynamic_sd_manager = DynamicSpeculativeDecodingManager( - self.speculative_config.dynamic_config, + self.speculative_config.dynamic_config, self.vllm_config.scheduler_config.max_num_seqs, - self.vllm_config.speculative_config.num_speculative_tokens) + self.vllm_config.speculative_config.num_speculative_tokens, + ) else: self.dynamic_sd_manager = None @@ -3598,19 +3599,21 @@ def propose_draft_token_ids( spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, ) -> list[list[int]] | torch.Tensor: - optimal_num_speculative_tokens = None if self.dynamic_sd_manager: batch_size = self.input_batch.num_reqs - optimal_num_speculative_tokens = self.dynamic_sd_manager.\ - get_optimal_num_speculative_tokens( + optimal_num_speculative_tokens = ( + self.dynamic_sd_manager.get_optimal_num_speculative_tokens( self.input_batch.num_reqs ) - + ) + # REMOVE - print(f"Batch size: {batch_size}, " - f"Optimal num speculative tokens: {optimal_num_speculative_tokens}") - + print( + f"Batch size: {batch_size}, " + f"Optimal num speculative tokens: {optimal_num_speculative_tokens}" + ) + 
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None From 5e63d5110f0766df4fb9da6342c85edea26855fe Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Sun, 4 Jan 2026 05:20:41 +0000 Subject: [PATCH 04/39] start stiching Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- examples/offline_inference/spec_decode.py | 6 +- vllm/config/speculative.py | 6 +- .../v1/spec_decode/dynamic/generate_config.py | 64 +++++ .../dynamic/online_profiling_client.py | 231 ----------------- .../spec_decode/dynamic/profiling_client.py | 237 ++++++++++++++++++ ...rofiling_server.py => profiling_server.py} | 0 6 files changed, 308 insertions(+), 236 deletions(-) create mode 100644 vllm/v1/spec_decode/dynamic/generate_config.py delete mode 100644 vllm/v1/spec_decode/dynamic/online_profiling_client.py create mode 100644 vllm/v1/spec_decode/dynamic/profiling_client.py rename vllm/v1/spec_decode/dynamic/{online_profiling_server.py => profiling_server.py} (100%) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 29b2e95d262f..8f3b8670e2cb 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -192,16 +192,18 @@ def main(args): print("-" * 50) # print acceptance at each token position + acceptance_rate_per_pos = [] for i in range(len(acceptance_counts)): acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 print(f"acceptance at token {i}: {acceptance_rate:.2f}") + acceptance_rate_per_pos.append(acceptance_rate) - return acceptance_length + return acceptance_length, acceptance_rate_per_pos if __name__ == "__main__": args = parse_args() - acceptance_length = main(args) + acceptance_length, acceptance_rate_per_pos = main(args) if args.test: # takes ~30s to run on 1xH100 diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 79f7e1435b6f..ca047712ccff 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -4,7 +4,7 @@ import ast from typing import TYPE_CHECKING, Any, Literal, get_args -from pydantic import Field, SkipValidation, model_validator +from pydantic import Field, SkipValidation, model_validator, BaseModel from pydantic.dataclasses import dataclass from typing_extensions import Self @@ -49,8 +49,8 @@ ] -@dataclass -class DynamicSpeculativeConfig: +# @dataclass +class DynamicSpeculativeConfig(BaseModel): # """A mapping from batch size to optimal number of drafts to use for that # batch size. 
This is used to dynamically adjust the number of drafts used # based on the current batch size.""" diff --git a/vllm/v1/spec_decode/dynamic/generate_config.py b/vllm/v1/spec_decode/dynamic/generate_config.py new file mode 100644 index 000000000000..369974b63cc0 --- /dev/null +++ b/vllm/v1/spec_decode/dynamic/generate_config.py @@ -0,0 +1,64 @@ +from vllm.config.speculative import DynamicSpeculativeConfig +from vllm.v1.spec_decode.dynamic.process_benchmark_results import parse_itl +from examples.offline_inference.spec_decode import main as spec_decode_main +from vllm.v1.spec_decode.dynamic.process_benchmark_results import parse_itl +from vllm.v1.spec_decode.dynamic.profiling_client import run_benchmarks +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.benchmarks.datasets import add_dataset_parser, get_samples + +def main(): + parser = FlexibleArgumentParser() + parser.add_argument("--model-dir", type=str, default=None) + parser.add_argument("--draft-dir", type=str, default=None) + parser.add_argument("--method", type=str, default="vanilla") + parser.add_argument( + "--num-speculative-tokens-list", nargs="*", type=int, default=[1, 3, 5] + ) + parser.add_argument( + "--batch-size-list", nargs="*", type=int, default=[1, 4, 16, 64, 256] + ) + parser.add_argument( + "--max-vllm-batch-size", + type=int, + help="Max vllm server batch size (max concurrency)", + ) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--result-dir", type=str, default="./log/dynamic_sd") + parser.add_argument("--extra-log-arg", type=str, default="") + + + args = parser.parse_args() + + # Step 1: get acceptance_rate_per_pos + acceptance_length, acceptance_rate_per_pos = spec_decode_main(args) + + # Step 2: generate benchmark data for vanilla and specified method + for method in ["vanilla", args.method]: + run_benchmarks( + dry_run = False, + model_dir = args.model_dir, + draft_dir = args.draft_dir, + method = method, + num_speculative_tokens_list = args.num_speculative_tokens_list, + batch_size_list = args.batch_size_list, + max_vllm_batch_size = args.max_vllm_batch_size, + tp = args.tp, + result_dir = args.result_dir, + extra_log_arg = args.extra_log_arg + ) + + # Step 3: parse batch_stats from benchmark data + batch_stats = parse_itl(args.result_dir) + + # Step 4: create DynamicSpeculativeConfig + dynamic_config = DynamicSpeculativeConfig( + is_online=False, + max_num_speculative_tokens=len(acceptance_rate_per_pos) + 1, + acceptance_rate_per_pos=acceptance_rate_per_pos, + batch_stats=batch_stats, + ) + + # Step 5: save dynamic_config to a json file + import json + with open(f"{args.result_dir}/dynamic_speculative_config.json", "w") as f: + dynamic_config.model_dump_json(f, indent=4) \ No newline at end of file diff --git a/vllm/v1/spec_decode/dynamic/online_profiling_client.py b/vllm/v1/spec_decode/dynamic/online_profiling_client.py deleted file mode 100644 index 4a88723de797..000000000000 --- a/vllm/v1/spec_decode/dynamic/online_profiling_client.py +++ /dev/null @@ -1,231 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse -import os -import subprocess -from dataclasses import dataclass - -from vllm.v1.spec_decode.dynamic.online_profiling_server import ( - kill_server, - setup_server, - start_server, -) - - -@dataclass -class Dataset: - name: str - config: list - - -NGRAM_FMT = "min-{min}-max-{max}-k-{k}" -EAGLE_FMT = "k-{k}" - - -def run_command(command): - try: - result = subprocess.run( - 
f"bash -c '{command}'", - shell=True, - check=True, - capture_output=True, - text=True, - ) - print("Output:") - print(result.stdout) - except subprocess.CalledProcessError as e: - print("Error:") - print(e.stderr) - - -def run_benchmarks(args): - # setup server - setup_server() - - port = 9001 - all_sampling_profile = [ - {"temperature": 0, "topp": 1}, # greedy - ] - - # `num_batches` decides how many batches are sent for each concurrency. - # E.g., num_batches=20 and concurrency=4 means total 80 prompts are sent - # such that we send 20 batches of 4 prompts each. This ensures a consistent - # number of batches across different concurrencies. For e.g., if total - # samples is 80 then concurrency 1 will send 80 batches while concurrency 64 - # will send 2 batches only. - MTBENCH_CONFIG = [{"num_batches": 20}] - - all_bench_dataset = [ - Dataset(name="philschmid/mt-bench", config=MTBENCH_CONFIG), - ] - - assert all(len(ds.config) > 0 for ds in all_bench_dataset), ( - "Each dataset must have at least one config" - ) - - all_ngram_params = [ - {"min": 2, "max": 5, "k": k} for k in args.num_speculative_tokens_list - ] - all_eagle_params = args.num_speculative_tokens_list - - # ablation - num_exp_run = 0 - for tp in args.tp_list: - spec_method = args.method - # collate all spec configs to run for a given method - all_spec_config = [] - if spec_method == "ngram": - for ngram_params in all_ngram_params: - all_spec_config.append( - { - "method": "ngram", - "num_speculative_tokens": ngram_params["k"], - "prompt_lookup_max": ngram_params["max"], - "prompt_lookup_min": ngram_params["min"], - } - ) - elif spec_method == "eagle": - for eagle_k in all_eagle_params: - all_spec_config.append( - { - "method": "eagle", - "model": args.draft_dir, - "num_speculative_tokens": eagle_k, - "draft_tensor_parallel_size": tp, - } - ) - else: - # vanilla case - all_spec_config.append(None) - - for spec_config in all_spec_config: - # start server - server_process = start_server( - port=port, - target_model_dir=args.model_dir, - spec_config=spec_config, - tp=tp, - max_vllm_bs=args.max_vllm_batch_size, - dry_run=args.dry_run, - ) - - # start client - for bench_concurrency in args.batch_size_list: - for bench_dataset_object in all_bench_dataset: - bench_dataset = bench_dataset_object.name - for bench_config in bench_dataset_object.config: - for sampling_profile in all_sampling_profile: - bench_temperature = sampling_profile["temperature"] - bench_topp = sampling_profile["topp"] - - spec_config_str = "vanilla" - if spec_method == "ngram": - spec_config_str = NGRAM_FMT.format( - min=spec_config["prompt_lookup_min"], - max=spec_config["prompt_lookup_max"], - k=spec_config["num_speculative_tokens"], - ) - elif spec_method == "eagle": - spec_config_str = EAGLE_FMT.format( - k=spec_config["num_speculative_tokens"] - ) - - # dataset specific config - if "philschmid/mt-bench" in bench_dataset: - bench_config_str = "mt_bench" - num_prompts = ( - bench_config["num_batches"] * bench_concurrency - ) - bench_vllm_serve_config = f"--dataset-name hf --dataset-path {bench_dataset} --num-prompts {num_prompts}" # noqa E501 - - print( - f"Number of prompts in {bench_dataset}: {num_prompts}" - ) - - # create dir if not exists - result_dir = f"{args.result_dir}/tp-{tp}_temp-{bench_temperature}_top_p-{bench_topp}/{bench_dataset}/{spec_method}/" # noqa E501 - if not os.path.exists(result_dir): - os.makedirs(result_dir) - - cmd = f'''time vllm bench serve --port {port} --save-result --save-detailed \ - --model {args.model_dir} \ - --backend 
openai-chat \ - --endpoint /v1/chat/completions \ - {bench_vllm_serve_config} \ - --max-concurrency {bench_concurrency} \ - --temperature={bench_temperature} \ - --top-p={bench_topp} \ - --result-dir "{result_dir}" \ - --result-filename "{spec_config_str}_{bench_config_str}_bs-{bench_concurrency}_{args.extra_log_arg}.txt"''' - - print(cmd) - num_exp_run += 1 - - if not args.dry_run: - run_command(cmd) - - # server teardown: kill server and any gpu processes - kill_server(port, server_process) - - print(f"Total number of experiments run: {num_exp_run}") - - -""" -# eagle -time python3 vllm/v1/spec_decode/online_profiling_client.py \ - --batch-size-list 1 4 16 64 256 \ - --num-speculative-tokens-list 1 3 5 \ - --max-vllm-batch-size 256 \ - --method eagle \ - --model-dir meta-llama/Llama-3.1-8B-Instruct \ - --draft-dir yuhuili/EAGLE-LLaMA3.1-Instruct-8B - -# vanilla -time python3 vllm/v1/spec_decode/online_profiling_client.py \ - --batch-size-list 1 4 16 64 256 \ - --max-vllm-batch-size 256 \ - --method vanilla -""" -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--dry-run", - action="store_true", - help="Run in dry run mode. If set, commands will be printed but not executed.", - ) - parser.add_argument("--model-dir", type=str, default=None) - parser.add_argument("--draft-dir", type=str, default=None) - # parser.add_argument("--method-list", nargs='*', type=str, default=["vanilla", "eagle"]) - parser.add_argument("--method", type=str, default="vanilla") - parser.add_argument( - "--num-speculative-tokens-list", nargs="*", type=int, default=[1, 3, 5] - ) - parser.add_argument( - "--batch-size-list", nargs="*", type=int, default=[1, 4, 16, 64, 256] - ) - parser.add_argument( - "--max-vllm-batch-size", - type=int, - help="Max vllm server batch size (max concurrency)", - ) - parser.add_argument("--tp-list", nargs="*", type=int, default=[1]) - parser.add_argument("--result-dir", type=str, default="./log/dynamic_sd") - parser.add_argument("--extra-log-arg", type=str, default="") - args = parser.parse_args() - - # assert all([method in ["vanilla", "ngram", "eagle", "eagle3"] for method in args.method_list]), \ - # "invalid method in method_list" - # assert 0 < len(args.method_list) <= 2 - # if len(args.method_list) == 2: - # assert "vanilla" in args.method_list, "If two methods are specified, one must be vanilla" - assert args.method in ["vanilla", "ngram", "eagle", "eagle3"], ( - "invalid method specified" - ) - - # assert 1 in args.batch_size_list, "batch_size must contain 1" - # assert 1 in args.num_speculative_tokens_list, "num_speculative_tokens must contain 1" - assert args.max_vllm_batch_size == max(args.batch_size_list), ( - "max_vllm_batch_size must be equal to max of batch_size" - ) - - run_benchmarks(args) diff --git a/vllm/v1/spec_decode/dynamic/profiling_client.py b/vllm/v1/spec_decode/dynamic/profiling_client.py new file mode 100644 index 000000000000..ff0580c226d3 --- /dev/null +++ b/vllm/v1/spec_decode/dynamic/profiling_client.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import os +import subprocess +from dataclasses import dataclass + +from vllm.v1.spec_decode.dynamic.profiling_server import ( + kill_server, + setup_server, + start_server, +) + + +@dataclass +class Dataset: + name: str + config: list + + +NGRAM_FMT = "min-{min}-max-{max}-k-{k}" +EAGLE_FMT = "k-{k}" + + +def run_command(command): + try: + result = subprocess.run( + f"bash -c 
'{command}'", + shell=True, + check=True, + capture_output=True, + text=True, + ) + print("Output:") + print(result.stdout) + except subprocess.CalledProcessError as e: + print("Error:") + print(e.stderr) + + +def run_benchmarks(dry_run, model_dir, draft_dir, method, + num_speculative_tokens_list, batch_size_list, + max_vllm_batch_size, tp, result_dir, extra_log_arg): + + assert method in ["vanilla", "ngram", "eagle", "eagle3"], ( + "invalid method specified" + ) + + assert max_vllm_batch_size == max(batch_size_list), ( + "max_vllm_batch_size must be equal to max of batch_size" + ) + + # setup server + setup_server() + + port = 9001 + all_sampling_profile = [ + {"temperature": 0, "topp": 1}, # greedy + ] + + # `num_batches` decides how many batches are sent for each concurrency. + # E.g., num_batches=20 and concurrency=4 means total 80 prompts are sent + # such that we send 20 batches of 4 prompts each. This ensures a consistent + # number of batches across different concurrencies. For e.g., if total + # samples is 80 then concurrency 1 will send 80 batches while concurrency 64 + # will send 2 batches only. + MTBENCH_CONFIG = [{"num_batches": 20}] + + all_bench_dataset = [ + Dataset(name="philschmid/mt-bench", config=MTBENCH_CONFIG), + ] + + assert all(len(ds.config) > 0 for ds in all_bench_dataset), ( + "Each dataset must have at least one config" + ) + + all_ngram_params = [ + {"min": 2, "max": 5, "k": k} for k in num_speculative_tokens_list + ] + all_eagle_params = num_speculative_tokens_list + + # ablation + num_exp_run = 0 + spec_method = method + # collate all spec configs to run for a given method + all_spec_config = [] + if spec_method == "ngram": + for ngram_params in all_ngram_params: + all_spec_config.append( + { + "method": "ngram", + "num_speculative_tokens": ngram_params["k"], + "prompt_lookup_max": ngram_params["max"], + "prompt_lookup_min": ngram_params["min"], + } + ) + elif spec_method == "eagle": + for eagle_k in all_eagle_params: + all_spec_config.append( + { + "method": "eagle", + "model": draft_dir, + "num_speculative_tokens": eagle_k, + "draft_tensor_parallel_size": tp, + } + ) + else: + # vanilla case + all_spec_config.append(None) + + for spec_config in all_spec_config: + # start server + server_process = start_server( + port=port, + target_model_dir=model_dir, + spec_config=spec_config, + tp=tp, + max_vllm_bs=max_vllm_batch_size, + dry_run=dry_run, + ) + + # start client + for bench_concurrency in batch_size_list: + for bench_dataset_object in all_bench_dataset: + bench_dataset = bench_dataset_object.name + for bench_config in bench_dataset_object.config: + for sampling_profile in all_sampling_profile: + bench_temperature = sampling_profile["temperature"] + bench_topp = sampling_profile["topp"] + + spec_config_str = "vanilla" + if spec_method == "ngram": + spec_config_str = NGRAM_FMT.format( + min=spec_config["prompt_lookup_min"], + max=spec_config["prompt_lookup_max"], + k=spec_config["num_speculative_tokens"], + ) + elif spec_method == "eagle": + spec_config_str = EAGLE_FMT.format( + k=spec_config["num_speculative_tokens"] + ) + + # dataset specific config + if "philschmid/mt-bench" in bench_dataset: + bench_config_str = "mt_bench" + num_prompts = ( + bench_config["num_batches"] * bench_concurrency + ) + bench_vllm_serve_config = f"--dataset-name hf --dataset-path {bench_dataset} --num-prompts {num_prompts}" # noqa E501 + + print( + f"Number of prompts in {bench_dataset}: {num_prompts}" + ) + + # create dir if not exists + # TODO: make the path shared with 
generate_config.py + result_dir = f"{result_dir}/tp-{tp}_temp-{bench_temperature}_top_p-{bench_topp}/{bench_dataset}/{spec_method}/" # noqa E501 + if not os.path.exists(result_dir): + os.makedirs(result_dir) + + cmd = f'''time vllm bench serve --port {port} --save-result --save-detailed \ + --model {model_dir} \ + --backend openai-chat \ + --endpoint /v1/chat/completions \ + {bench_vllm_serve_config} \ + --max-concurrency {bench_concurrency} \ + --temperature={bench_temperature} \ + --top-p={bench_topp} \ + --result-dir "{result_dir}" \ + --result-filename "{spec_config_str}_{bench_config_str}_bs-{bench_concurrency}_{extra_log_arg}.txt"''' + + print(cmd) + num_exp_run += 1 + + if not dry_run: + run_command(cmd) + + # server teardown: kill server and any gpu processes + kill_server(port, server_process) + + print(f"Total number of experiments run: {num_exp_run}") + + +""" +# eagle +time python3 vllm/v1/spec_decode/online_profiling_client.py \ + --batch-size-list 1 4 16 64 256 \ + --num-speculative-tokens-list 1 3 5 \ + --max-vllm-batch-size 256 \ + --method eagle \ + --model-dir meta-llama/Llama-3.1-8B-Instruct \ + --draft-dir yuhuili/EAGLE-LLaMA3.1-Instruct-8B + +# vanilla +time python3 vllm/v1/spec_decode/online_profiling_client.py \ + --batch-size-list 1 4 16 64 256 \ + --max-vllm-batch-size 256 \ + --method vanilla +""" +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dry-run", + action="store_true", + help="Run in dry run mode. If set, commands will be printed but not executed.", + ) + parser.add_argument("--model-dir", type=str, default=None) + parser.add_argument("--draft-dir", type=str, default=None) + # parser.add_argument("--method-list", nargs='*', type=str, default=["vanilla", "eagle"]) + parser.add_argument("--method", type=str, default="vanilla") + parser.add_argument( + "--num-speculative-tokens-list", nargs="*", type=int, default=[1, 3, 5] + ) + parser.add_argument( + "--batch-size-list", nargs="*", type=int, default=[1, 4, 16, 64, 256] + ) + parser.add_argument( + "--max-vllm-batch-size", + type=int, + help="Max vllm server batch size (max concurrency)", + ) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--result-dir", type=str, default="./log/dynamic_sd") + parser.add_argument("--extra-log-arg", type=str, default="") + args = parser.parse_args() + + run_benchmarks( + dry_run = args.dry_run, + model_dir = args.model_dir, + draft_dir = args.draft_dir, + method = args.method, + num_speculative_tokens_list = args.num_speculative_tokens_list, + batch_size_list = args.batch_size_list, + max_vllm_batch_size = args.max_vllm_batch_size, + tp = args.tp, + result_dir = args.result_dir, + extra_log_arg = args.extra_log_arg) \ No newline at end of file diff --git a/vllm/v1/spec_decode/dynamic/online_profiling_server.py b/vllm/v1/spec_decode/dynamic/profiling_server.py similarity index 100% rename from vllm/v1/spec_decode/dynamic/online_profiling_server.py rename to vllm/v1/spec_decode/dynamic/profiling_server.py From b683f26c5e33215819b849cb9b7c0fad38c3867b Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 5 Jan 2026 16:49:21 +0000 Subject: [PATCH 05/39] pipeline works and offline script moved Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- examples/offline_inference/spec_decode.py | 236 +---------------- examples/offline_inference/spec_decode_bkp.py | 236 +++++++++++++++++ .../v1/spec_decode/dynamic/generate_config.py | 106 
+++++++- .../dynamic/{dynamic.py => manager.py} | 36 ++- .../dynamic/process_benchmark_results.py | 25 +- .../spec_decode/dynamic/profiling_client.py | 185 +++++++------- vllm/v1/spec_decode/eagle.py | 2 +- vllm/v1/spec_decode/offline.py | 240 ++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 2 +- 9 files changed, 706 insertions(+), 362 deletions(-) create mode 100644 examples/offline_inference/spec_decode_bkp.py rename vllm/v1/spec_decode/dynamic/{dynamic.py => manager.py} (87%) create mode 100644 vllm/v1/spec_decode/offline.py diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 8f3b8670e2cb..df64f75e47b0 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -1,236 +1,4 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from transformers import AutoTokenizer - -from vllm import LLM, SamplingParams -from vllm.benchmarks.datasets import add_dataset_parser, get_samples -from vllm.inputs import TokensPrompt -from vllm.v1.metrics.reader import Counter, Vector - -try: - from vllm.utils.argparse_utils import FlexibleArgumentParser -except ImportError: - from argparse import ArgumentParser as FlexibleArgumentParser - - -QUESTION = "What is the content of each image?" -IMAGE_URLS = [ - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg", -] - - -def get_custom_mm_prompts(num_prompts): - prompts = [] - for url in IMAGE_URLS: - prompts.append( - [ - {"type": "image_url", "image_url": {"url": url}}, - {"type": "text", "text": QUESTION}, - ] - ) - if num_prompts > len(IMAGE_URLS): - prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1) - - return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]] - - -def parse_args(): - parser = FlexibleArgumentParser() - add_dataset_parser(parser) - parser.add_argument("--test", action="store_true") - parser.add_argument( - "--method", - type=str, - default="eagle", - choices=["ngram", "eagle", "eagle3", "mtp"], - ) - parser.add_argument("--num-spec-tokens", type=int, default=2) - parser.add_argument("--prompt-lookup-max", type=int, default=5) - parser.add_argument("--prompt-lookup-min", type=int, default=2) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--enforce-eager", action="store_true") - parser.add_argument("--enable-chunked-prefill", action="store_true") - parser.add_argument("--max-model-len", type=int, default=16384) - 
parser.add_argument("--temp", type=float, default=0) - parser.add_argument("--top-p", type=float, default=1.0) - parser.add_argument("--top-k", type=int, default=-1) - parser.add_argument("--print-output", action="store_true") - parser.add_argument("--output-len", type=int, default=256) - parser.add_argument("--model-dir", type=str, default=None) - parser.add_argument("--eagle-dir", type=str, default=None) - parser.add_argument("--custom-mm-prompts", action="store_true") - return parser.parse_args() - - -def main(args): - args.endpoint_type = "openai-chat" - - model_dir = args.model_dir - if args.model_dir is None: - if args.custom_mm_prompts: - raise ValueError( - "custom_mm_prompts requires mm based models" - "default llama3.1-8b-instruct is not mm based" - "please specify model_dir to give a mm based model" - ) - model_dir = "meta-llama/Llama-3.1-8B-Instruct" - tokenizer = AutoTokenizer.from_pretrained(model_dir) - args.custom_skip_chat_template = True - - if not args.custom_mm_prompts: - prompts = get_samples(args, tokenizer) - # add_special_tokens is False to avoid adding bos twice - # when using chat templates - prompt_ids = [ - tokenizer.encode(prompt.prompt, add_special_tokens=False) - for prompt in prompts - ] - else: - prompts = get_custom_mm_prompts(args.num_prompts) - - if args.method == "eagle" or args.method == "eagle3": - eagle_dir = args.eagle_dir - if args.method == "eagle" and eagle_dir is None: - eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" - - elif args.method == "eagle3" and eagle_dir is None: - eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" - speculative_config = { - "method": args.method, - "model": eagle_dir, - "num_speculative_tokens": args.num_spec_tokens, - } - elif args.method == "ngram": - speculative_config = { - "method": "ngram", - "num_speculative_tokens": args.num_spec_tokens, - "prompt_lookup_max": args.prompt_lookup_max, - "prompt_lookup_min": args.prompt_lookup_min, - } - elif args.method == "mtp": - speculative_config = { - "method": "mtp", - "num_speculative_tokens": args.num_spec_tokens, - } - else: - raise ValueError(f"unknown method: {args.method}") - - llm = LLM( - model=model_dir, - trust_remote_code=True, - tensor_parallel_size=args.tp, - enable_chunked_prefill=args.enable_chunked_prefill, - enforce_eager=args.enforce_eager, - gpu_memory_utilization=0.9, - speculative_config=speculative_config, - disable_log_stats=False, - max_model_len=args.max_model_len, - limit_mm_per_prompt={"image": 5}, - disable_chunked_mm_input=True, - ) - - sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) - if not args.custom_mm_prompts: - outputs = llm.generate( - [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], - sampling_params=sampling_params, - ) - else: - outputs = llm.chat(prompts, sampling_params=sampling_params) - - # print the generated text - if args.print_output: - for output in outputs: - print("-" * 50) - print(f"prompt: {output.prompt}") - print(f"generated text: {output.outputs[0].text}") - print("-" * 50) - - metrics = llm.get_metrics() - - total_num_output_tokens = sum( - len(output.outputs[0].token_ids) for output in outputs - ) - num_drafts = 0 - num_draft_tokens = 0 - num_accepted_tokens = 0 - acceptance_counts = [0] * args.num_spec_tokens - for metric in metrics: - if metric.name == "vllm:spec_decode_num_drafts": - assert isinstance(metric, Counter) - num_drafts += metric.value - elif metric.name == "vllm:spec_decode_num_draft_tokens": - assert isinstance(metric, Counter) - num_draft_tokens += 
metric.value - elif metric.name == "vllm:spec_decode_num_accepted_tokens": - assert isinstance(metric, Counter) - num_accepted_tokens += metric.value - elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": - assert isinstance(metric, Vector) - for pos in range(len(metric.values)): - acceptance_counts[pos] += metric.values[pos] - - print("-" * 50) - print(f"total_num_output_tokens: {total_num_output_tokens}") - print(f"num_drafts: {num_drafts}") - print(f"num_draft_tokens: {num_draft_tokens}") - print(f"num_accepted_tokens: {num_accepted_tokens}") - acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 - print(f"mean acceptance length: {acceptance_length:.2f}") - print("-" * 50) - - # print acceptance at each token position - acceptance_rate_per_pos = [] - for i in range(len(acceptance_counts)): - acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 - print(f"acceptance at token {i}: {acceptance_rate:.2f}") - acceptance_rate_per_pos.append(acceptance_rate) - - return acceptance_length, acceptance_rate_per_pos - +from vllm.v1.spec_decode.offline import entrypoint as spec_decode_main if __name__ == "__main__": - args = parse_args() - acceptance_length, acceptance_rate_per_pos = main(args) - - if args.test: - # takes ~30s to run on 1xH100 - assert args.method in ["eagle", "eagle3"] - assert args.tp == 1 - assert args.num_spec_tokens == 3 - assert args.dataset_name == "hf" - assert args.dataset_path == "philschmid/mt-bench" - assert args.num_prompts == 80 - assert args.temp == 0 - assert args.top_p == 1.0 - assert args.top_k == -1 - assert args.enable_chunked_prefill - - # check acceptance length is within 2% of expected value - rtol = 0.02 - expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811 - - assert ( - acceptance_length <= (1 + rtol) * expected_acceptance_length - and acceptance_length >= (1 - rtol) * expected_acceptance_length - ), ( - f"acceptance_length {acceptance_length} is not " - f"within {rtol * 100}% of {expected_acceptance_length}" - ) - - print( - f"Test passed! Expected AL: " - f"{expected_acceptance_length}, got {acceptance_length}" - ) + spec_decode_main() \ No newline at end of file diff --git a/examples/offline_inference/spec_decode_bkp.py b/examples/offline_inference/spec_decode_bkp.py new file mode 100644 index 000000000000..8f3b8670e2cb --- /dev/null +++ b/examples/offline_inference/spec_decode_bkp.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.benchmarks.datasets import add_dataset_parser, get_samples +from vllm.inputs import TokensPrompt +from vllm.v1.metrics.reader import Counter, Vector + +try: + from vllm.utils.argparse_utils import FlexibleArgumentParser +except ImportError: + from argparse import ArgumentParser as FlexibleArgumentParser + + +QUESTION = "What is the content of each image?" 
+IMAGE_URLS = [ + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg", +] + + +def get_custom_mm_prompts(num_prompts): + prompts = [] + for url in IMAGE_URLS: + prompts.append( + [ + {"type": "image_url", "image_url": {"url": url}}, + {"type": "text", "text": QUESTION}, + ] + ) + if num_prompts > len(IMAGE_URLS): + prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1) + + return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]] + + +def parse_args(): + parser = FlexibleArgumentParser() + add_dataset_parser(parser) + parser.add_argument("--test", action="store_true") + parser.add_argument( + "--method", + type=str, + default="eagle", + choices=["ngram", "eagle", "eagle3", "mtp"], + ) + parser.add_argument("--num-spec-tokens", type=int, default=2) + parser.add_argument("--prompt-lookup-max", type=int, default=5) + parser.add_argument("--prompt-lookup-min", type=int, default=2) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--enforce-eager", action="store_true") + parser.add_argument("--enable-chunked-prefill", action="store_true") + parser.add_argument("--max-model-len", type=int, default=16384) + parser.add_argument("--temp", type=float, default=0) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--top-k", type=int, default=-1) + parser.add_argument("--print-output", action="store_true") + parser.add_argument("--output-len", type=int, default=256) + parser.add_argument("--model-dir", type=str, default=None) + parser.add_argument("--eagle-dir", type=str, default=None) + parser.add_argument("--custom-mm-prompts", action="store_true") + return parser.parse_args() + + +def main(args): + args.endpoint_type = "openai-chat" + + model_dir = args.model_dir + if args.model_dir is None: + if args.custom_mm_prompts: + raise ValueError( + "custom_mm_prompts requires mm based models" + "default llama3.1-8b-instruct is not mm based" + "please specify model_dir to give a mm based model" + ) + model_dir = "meta-llama/Llama-3.1-8B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_dir) + args.custom_skip_chat_template = True + + if not args.custom_mm_prompts: + prompts = get_samples(args, tokenizer) + # add_special_tokens is False to avoid adding bos twice + # when using chat templates + prompt_ids = [ + tokenizer.encode(prompt.prompt, add_special_tokens=False) + for prompt in prompts + ] + else: + prompts = get_custom_mm_prompts(args.num_prompts) + + if args.method == "eagle" or args.method == "eagle3": + eagle_dir = 
args.eagle_dir + if args.method == "eagle" and eagle_dir is None: + eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + + elif args.method == "eagle3" and eagle_dir is None: + eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + speculative_config = { + "method": args.method, + "model": eagle_dir, + "num_speculative_tokens": args.num_spec_tokens, + } + elif args.method == "ngram": + speculative_config = { + "method": "ngram", + "num_speculative_tokens": args.num_spec_tokens, + "prompt_lookup_max": args.prompt_lookup_max, + "prompt_lookup_min": args.prompt_lookup_min, + } + elif args.method == "mtp": + speculative_config = { + "method": "mtp", + "num_speculative_tokens": args.num_spec_tokens, + } + else: + raise ValueError(f"unknown method: {args.method}") + + llm = LLM( + model=model_dir, + trust_remote_code=True, + tensor_parallel_size=args.tp, + enable_chunked_prefill=args.enable_chunked_prefill, + enforce_eager=args.enforce_eager, + gpu_memory_utilization=0.9, + speculative_config=speculative_config, + disable_log_stats=False, + max_model_len=args.max_model_len, + limit_mm_per_prompt={"image": 5}, + disable_chunked_mm_input=True, + ) + + sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) + if not args.custom_mm_prompts: + outputs = llm.generate( + [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], + sampling_params=sampling_params, + ) + else: + outputs = llm.chat(prompts, sampling_params=sampling_params) + + # print the generated text + if args.print_output: + for output in outputs: + print("-" * 50) + print(f"prompt: {output.prompt}") + print(f"generated text: {output.outputs[0].text}") + print("-" * 50) + + metrics = llm.get_metrics() + + total_num_output_tokens = sum( + len(output.outputs[0].token_ids) for output in outputs + ) + num_drafts = 0 + num_draft_tokens = 0 + num_accepted_tokens = 0 + acceptance_counts = [0] * args.num_spec_tokens + for metric in metrics: + if metric.name == "vllm:spec_decode_num_drafts": + assert isinstance(metric, Counter) + num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_draft_tokens": + assert isinstance(metric, Counter) + num_draft_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens": + assert isinstance(metric, Counter) + num_accepted_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + assert isinstance(metric, Vector) + for pos in range(len(metric.values)): + acceptance_counts[pos] += metric.values[pos] + + print("-" * 50) + print(f"total_num_output_tokens: {total_num_output_tokens}") + print(f"num_drafts: {num_drafts}") + print(f"num_draft_tokens: {num_draft_tokens}") + print(f"num_accepted_tokens: {num_accepted_tokens}") + acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 + print(f"mean acceptance length: {acceptance_length:.2f}") + print("-" * 50) + + # print acceptance at each token position + acceptance_rate_per_pos = [] + for i in range(len(acceptance_counts)): + acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 + print(f"acceptance at token {i}: {acceptance_rate:.2f}") + acceptance_rate_per_pos.append(acceptance_rate) + + return acceptance_length, acceptance_rate_per_pos + + +if __name__ == "__main__": + args = parse_args() + acceptance_length, acceptance_rate_per_pos = main(args) + + if args.test: + # takes ~30s to run on 1xH100 + assert args.method in ["eagle", "eagle3"] + assert args.tp == 1 + assert args.num_spec_tokens == 3 + assert 
args.dataset_name == "hf" + assert args.dataset_path == "philschmid/mt-bench" + assert args.num_prompts == 80 + assert args.temp == 0 + assert args.top_p == 1.0 + assert args.top_k == -1 + assert args.enable_chunked_prefill + + # check acceptance length is within 2% of expected value + rtol = 0.02 + expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811 + + assert ( + acceptance_length <= (1 + rtol) * expected_acceptance_length + and acceptance_length >= (1 - rtol) * expected_acceptance_length + ), ( + f"acceptance_length {acceptance_length} is not " + f"within {rtol * 100}% of {expected_acceptance_length}" + ) + + print( + f"Test passed! Expected AL: " + f"{expected_acceptance_length}, got {acceptance_length}" + ) diff --git a/vllm/v1/spec_decode/dynamic/generate_config.py b/vllm/v1/spec_decode/dynamic/generate_config.py index 369974b63cc0..d27bffacaf22 100644 --- a/vllm/v1/spec_decode/dynamic/generate_config.py +++ b/vllm/v1/spec_decode/dynamic/generate_config.py @@ -1,16 +1,21 @@ +import json +import time + from vllm.config.speculative import DynamicSpeculativeConfig from vllm.v1.spec_decode.dynamic.process_benchmark_results import parse_itl -from examples.offline_inference.spec_decode import main as spec_decode_main +from vllm.v1.spec_decode.offline import main as spec_decode_main from vllm.v1.spec_decode.dynamic.process_benchmark_results import parse_itl from vllm.v1.spec_decode.dynamic.profiling_client import run_benchmarks from vllm.utils.argparse_utils import FlexibleArgumentParser -from vllm.benchmarks.datasets import add_dataset_parser, get_samples +from vllm.benchmarks.datasets import add_dataset_parser def main(): parser = FlexibleArgumentParser() + add_dataset_parser(parser) + parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--draft-dir", type=str, default=None) - parser.add_argument("--method", type=str, default="vanilla") + parser.add_argument("--method", type=str, default="eagle", choices=["ngram", "eagle", "eagle3", "mtp"]) parser.add_argument( "--num-speculative-tokens-list", nargs="*", type=int, default=[1, 3, 5] ) @@ -25,13 +30,37 @@ def main(): parser.add_argument("--tp", type=int, default=1) parser.add_argument("--result-dir", type=str, default="./log/dynamic_sd") parser.add_argument("--extra-log-arg", type=str, default="") - + parser.add_argument("--prompt-lookup-max", type=int, default=5) + parser.add_argument("--prompt-lookup-min", type=int, default=2) + parser.add_argument("--max-model-len", type=int, default=16384) + parser.add_argument("--temp", type=float, default=0) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--top-k", type=int, default=-1) + parser.add_argument("--output-len", type=int, default=256) + parser.add_argument("--num-batches", type=int, default=20, help="Number of batches to run for each benchmark.") + parser.add_argument("--custom-mm-prompts", action="store_true") + args = parser.parse_args() + args.enable_chunked_prefill = True + args.enforce_eager = False + args.print_output = False + args.num_spec_tokens = max(args.num_speculative_tokens_list) + args.eagle_dir = args.draft_dir + args.result_dir = f"{args.result_dir}/tp-{args.tp}_temp-{args.temp}_top_p-{args.top_p}_top_k-{args.top_k}/{args.dataset_path}/" + + # print the args in pretty format + import pprint + pprint.pprint(vars(args)) + start = time.time() + # Step 1: get acceptance_rate_per_pos acceptance_length, acceptance_rate_per_pos = spec_decode_main(args) - + print(f"Acceptance length: 
{acceptance_length}") + print(f"Acceptance rate per position: {acceptance_rate_per_pos}") + print(f"✅ Step 1: obtained acceptance rate per position.") + # Step 2: generate benchmark data for vanilla and specified method for method in ["vanilla", args.method]: run_benchmarks( @@ -39,26 +68,79 @@ def main(): model_dir = args.model_dir, draft_dir = args.draft_dir, method = method, + prompt_lookup_max = args.prompt_lookup_max, + prompt_lookup_min = args.prompt_lookup_min, num_speculative_tokens_list = args.num_speculative_tokens_list, batch_size_list = args.batch_size_list, max_vllm_batch_size = args.max_vllm_batch_size, tp = args.tp, + temp = args.temp, + top_p = args.top_p, + top_k = args.top_k, + num_batches = args.num_batches, + dataset_name = args.dataset_name, + dataset_path = args.dataset_path, result_dir = args.result_dir, extra_log_arg = args.extra_log_arg ) + print(f"✅ Step 2: benchmark data generated for vanilla and {args.method}.") # Step 3: parse batch_stats from benchmark data - batch_stats = parse_itl(args.result_dir) - - # Step 4: create DynamicSpeculativeConfig + batch_stats = parse_itl(method=args.method, benchmark_path_parent=args.result_dir) + print(f"✅ Step 3: parsed batch statistics from benchmark data.") + + # Step 4: Save DynamicSpeculativeConfig to a json file dynamic_config = DynamicSpeculativeConfig( is_online=False, - max_num_speculative_tokens=len(acceptance_rate_per_pos) + 1, + max_num_speculative_tokens=len(acceptance_rate_per_pos), acceptance_rate_per_pos=acceptance_rate_per_pos, batch_stats=batch_stats, ) - # Step 5: save dynamic_config to a json file - import json with open(f"{args.result_dir}/dynamic_speculative_config.json", "w") as f: - dynamic_config.model_dump_json(f, indent=4) \ No newline at end of file + json.dump(dynamic_config.model_dump(), f, indent=4) + + print(f"✅ Step 4: config saved to {args.result_dir}/dynamic_speculative_config.json") + + end = time.time() + print(f"Total time taken: {end - start:.2f} seconds") + + +""" +time python3 vllm/v1/spec_decode/dynamic/generate_config.py \ + --method eagle \ + --model-dir 'meta-llama/Llama-3.1-8B-Instruct' \ + --draft-dir 'yuhuili/EAGLE-LLaMA3.1-Instruct-8B' \ + --tp 1 \ + --temp 0 \ + --top-p 1.0 \ + --top-k -1 \ + --max-vllm-batch-size 256 \ + --batch-size-list 1 4 16 64 256 \ + --num-speculative-tokens-list 1 3 5 \ + --num-batches 20 \ + --dataset-name hf \ + --dataset-path 'philschmid/mt-bench' \ + --no-oversample \ + --result-dir './log/dynamic_sd_test' + +# shorter version: +time python3 vllm/v1/spec_decode/dynamic/generate_config.py \ + --method eagle \ + --model-dir 'meta-llama/Llama-3.1-8B-Instruct' \ + --draft-dir 'yuhuili/EAGLE-LLaMA3.1-Instruct-8B' \ + --tp 1 \ + --temp 0 \ + --top-p 1.0 \ + --top-k -1 \ + --max-vllm-batch-size 256 \ + --batch-size-list 1 256 \ + --num-speculative-tokens-list 1 5 \ + --num-batches 20 \ + --dataset-name hf \ + --dataset-path 'philschmid/mt-bench' \ + --no-oversample \ + --result-dir './log/dynamic_sd_test' +""" +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/vllm/v1/spec_decode/dynamic/dynamic.py b/vllm/v1/spec_decode/dynamic/manager.py similarity index 87% rename from vllm/v1/spec_decode/dynamic/dynamic.py rename to vllm/v1/spec_decode/dynamic/manager.py index 91c3b4223f9e..a576421a7d43 100644 --- a/vllm/v1/spec_decode/dynamic/dynamic.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from 
vllm.config.speculative import DynamicSpeculativeConfig +import json +from vllm.utils.argparse_utils import FlexibleArgumentParser # _DYNAMIC_STATS = { # "max_num_speculative_tokens": 7, @@ -220,15 +222,35 @@ def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: return chosen_num_drafts -# python3 vllm/v1/spec_decode/dynamic.py +# python3 vllm/v1/spec_decode/dynamic/manager.py --dynamic-config-path log/dynamic_sd_test_2/tp-1_temp-0.0_top_p-1.0_top_k--1/philschmid/mt-bench/dynamic_speculative_config.json if __name__ == "__main__": - MAX_TEST_BS = 128 - dynamic_sd = DynamicSpeculativeDecodingManager( - dynamic_config=_DYNAMIC_STATS, - vllm_max_batch_size=MAX_TEST_BS, - vllm_num_speculative_tokens=7, + parser = FlexibleArgumentParser() + parser.add_argument( + "--dynamic-config-path", + type=str, + default=None, + help="Path to the dynamic speculative decoding config json file.", ) - for i in range(4, MAX_TEST_BS + 1, 4): + args = parser.parse_args() + + MAX_TEST_BS = 128 + if args.dynamic_config_path: + with open(args.dynamic_config_path) as f: + data = json.load(f) + + dynamic_config = DynamicSpeculativeConfig.model_validate(data) + dynamic_sd = DynamicSpeculativeDecodingManager( + dynamic_config=dynamic_config, + vllm_max_batch_size=max(dynamic_config.batch_stats.keys()), + vllm_num_speculative_tokens=dynamic_config.max_num_speculative_tokens, + ) + else: + dynamic_sd = DynamicSpeculativeDecodingManager( + dynamic_config=_DYNAMIC_STATS, + vllm_max_batch_size=MAX_TEST_BS, + vllm_num_speculative_tokens=7, + ) + for i in range(1, dynamic_sd.vllm_max_batch_size + 1, 4): print("\n====================================") print( f"bs: {i}, optimal num drafts: {dynamic_sd.get_optimal_num_speculative_tokens(i)}" diff --git a/vllm/v1/spec_decode/dynamic/process_benchmark_results.py b/vllm/v1/spec_decode/dynamic/process_benchmark_results.py index f0d2a11b685b..ac5548176713 100644 --- a/vllm/v1/spec_decode/dynamic/process_benchmark_results.py +++ b/vllm/v1/spec_decode/dynamic/process_benchmark_results.py @@ -5,7 +5,7 @@ import os import re -from vllm.v1.spec_decode.dynamic.online_profiling_client import EAGLE_FMT, NGRAM_FMT +from vllm.v1.spec_decode.dynamic.profiling_client import EAGLE_FMT, NGRAM_FMT def reverse_fmt(fmt_str): @@ -20,7 +20,7 @@ def reverse_fmt(fmt_str): EAGLE_FMT_REVERSE = reverse_fmt(EAGLE_FMT) -def parse_itl(args): +def parse_itl(method, benchmark_path_parent): """ DynamicSpeculativeConfig.batch_stats: dict The structure is as follows: @@ -32,13 +32,13 @@ def parse_itl(args): """ batch_stats = {} - for method in ["vanilla", args.sd_method]: + for method in ["vanilla", method]: # find the names of all log files in this folder - args.benchmark_path = os.path.join(args.benchmark_path_parent, method) + benchmark_path = os.path.join(benchmark_path_parent, method) all_log_files = [ f - for f in os.listdir(args.benchmark_path) - if os.path.isfile(os.path.join(args.benchmark_path, f)) + for f in os.listdir(benchmark_path) + if os.path.isfile(os.path.join(benchmark_path, f)) and f.endswith(".txt") ] @@ -57,7 +57,7 @@ def parse_itl(args): k = re.match(EAGLE_FMT_REVERSE, spec_config_str).groups()[0] # read the log file to get the itl - with open(os.path.join(args.benchmark_path, log_file)) as f: + with open(os.path.join(benchmark_path, log_file)) as f: data = json.load(f) itl = data["median_itl_ms"] @@ -67,17 +67,18 @@ def parse_itl(args): batch_stats[int(bs)][int(k)] = itl + print(json.dumps(batch_stats, indent=4)) return batch_stats """ python3 
vllm/v1/spec_decode/process_benchmark_results.py \ - --sd-method eagle \ + --method eagle \ --benchmark-path-parent 'log/dynamic_sd/tp-1_temp-0_top_p-1/philschmid/mt-bench/' """ if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--sd-method", type=str, default=None) + parser.add_argument("--method", type=str, default=None) parser.add_argument( "--benchmark-path-parent", type=str, @@ -86,7 +87,7 @@ def parse_itl(args): ) args = parser.parse_args() - assert args.sd_method in ["ngram", "eagle"], "Invalid method specified." + assert args.method in ["ngram", "eagle", "eagle3", "mtp"], "Invalid method specified." - batch_stats = parse_itl(args) - print(json.dumps(batch_stats, indent=4)) + batch_stats = parse_itl(method=args.method, + benchmark_path_parent=args.benchmark_path_parent) diff --git a/vllm/v1/spec_decode/dynamic/profiling_client.py b/vllm/v1/spec_decode/dynamic/profiling_client.py index ff0580c226d3..294e4e72f7a9 100644 --- a/vllm/v1/spec_decode/dynamic/profiling_client.py +++ b/vllm/v1/spec_decode/dynamic/profiling_client.py @@ -12,12 +12,6 @@ ) -@dataclass -class Dataset: - name: str - config: list - - NGRAM_FMT = "min-{min}-max-{max}-k-{k}" EAGLE_FMT = "k-{k}" @@ -38,9 +32,24 @@ def run_command(command): print(e.stderr) -def run_benchmarks(dry_run, model_dir, draft_dir, method, - num_speculative_tokens_list, batch_size_list, - max_vllm_batch_size, tp, result_dir, extra_log_arg): +def run_benchmarks(dry_run, + model_dir, + draft_dir, + method, + prompt_lookup_max, + prompt_lookup_min, + num_speculative_tokens_list, + batch_size_list, + max_vllm_batch_size, + tp, + temp, + top_p, + top_k, + num_batches, + dataset_name, + dataset_path, + result_dir, + extra_log_arg): assert method in ["vanilla", "ngram", "eagle", "eagle3"], ( "invalid method specified" @@ -54,48 +63,24 @@ def run_benchmarks(dry_run, model_dir, draft_dir, method, setup_server() port = 9001 - all_sampling_profile = [ - {"temperature": 0, "topp": 1}, # greedy - ] - - # `num_batches` decides how many batches are sent for each concurrency. - # E.g., num_batches=20 and concurrency=4 means total 80 prompts are sent - # such that we send 20 batches of 4 prompts each. This ensures a consistent - # number of batches across different concurrencies. For e.g., if total - # samples is 80 then concurrency 1 will send 80 batches while concurrency 64 - # will send 2 batches only. 
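A minimal sketch of the prompt-count arithmetic described in the comment above (the helper name and its default are illustrative only, not part of this patch; the same product shows up later in the client as num_prompts = num_batches * bench_concurrency):

    def num_prompts_for(concurrency: int, num_batches: int = 20) -> int:
        # Keep the number of batches constant across concurrencies by scaling
        # the total prompt count with the concurrency level.
        return num_batches * concurrency

    assert num_prompts_for(4) == 80     # 20 batches of 4 prompts
    assert num_prompts_for(64) == 1280  # 20 batches of 64 prompts
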
- MTBENCH_CONFIG = [{"num_batches": 20}] - - all_bench_dataset = [ - Dataset(name="philschmid/mt-bench", config=MTBENCH_CONFIG), - ] - - assert all(len(ds.config) > 0 for ds in all_bench_dataset), ( - "Each dataset must have at least one config" - ) - - all_ngram_params = [ - {"min": 2, "max": 5, "k": k} for k in num_speculative_tokens_list - ] - all_eagle_params = num_speculative_tokens_list # ablation num_exp_run = 0 - spec_method = method + # collate all spec configs to run for a given method all_spec_config = [] - if spec_method == "ngram": - for ngram_params in all_ngram_params: + if method == "ngram": + for ngram_k in num_speculative_tokens_list: all_spec_config.append( { "method": "ngram", - "num_speculative_tokens": ngram_params["k"], - "prompt_lookup_max": ngram_params["max"], - "prompt_lookup_min": ngram_params["min"], + "num_speculative_tokens": ngram_k, + "prompt_lookup_max": prompt_lookup_max, + "prompt_lookup_min": prompt_lookup_min, } ) - elif spec_method == "eagle": - for eagle_k in all_eagle_params: + elif method == "eagle": + for eagle_k in num_speculative_tokens_list: all_spec_config.append( { "method": "eagle", @@ -109,6 +94,7 @@ def run_benchmarks(dry_run, model_dir, draft_dir, method, all_spec_config.append(None) for spec_config in all_spec_config: + # start server server_process = start_server( port=port, @@ -121,59 +107,53 @@ def run_benchmarks(dry_run, model_dir, draft_dir, method, # start client for bench_concurrency in batch_size_list: - for bench_dataset_object in all_bench_dataset: - bench_dataset = bench_dataset_object.name - for bench_config in bench_dataset_object.config: - for sampling_profile in all_sampling_profile: - bench_temperature = sampling_profile["temperature"] - bench_topp = sampling_profile["topp"] - - spec_config_str = "vanilla" - if spec_method == "ngram": - spec_config_str = NGRAM_FMT.format( - min=spec_config["prompt_lookup_min"], - max=spec_config["prompt_lookup_max"], - k=spec_config["num_speculative_tokens"], - ) - elif spec_method == "eagle": - spec_config_str = EAGLE_FMT.format( - k=spec_config["num_speculative_tokens"] - ) - - # dataset specific config - if "philschmid/mt-bench" in bench_dataset: - bench_config_str = "mt_bench" - num_prompts = ( - bench_config["num_batches"] * bench_concurrency - ) - bench_vllm_serve_config = f"--dataset-name hf --dataset-path {bench_dataset} --num-prompts {num_prompts}" # noqa E501 - - print( - f"Number of prompts in {bench_dataset}: {num_prompts}" - ) - - # create dir if not exists - # TODO: make the path shared with generate_config.py - result_dir = f"{result_dir}/tp-{tp}_temp-{bench_temperature}_top_p-{bench_topp}/{bench_dataset}/{spec_method}/" # noqa E501 - if not os.path.exists(result_dir): - os.makedirs(result_dir) - - cmd = f'''time vllm bench serve --port {port} --save-result --save-detailed \ - --model {model_dir} \ - --backend openai-chat \ - --endpoint /v1/chat/completions \ - {bench_vllm_serve_config} \ - --max-concurrency {bench_concurrency} \ - --temperature={bench_temperature} \ - --top-p={bench_topp} \ - --result-dir "{result_dir}" \ - --result-filename "{spec_config_str}_{bench_config_str}_bs-{bench_concurrency}_{extra_log_arg}.txt"''' - - print(cmd) - num_exp_run += 1 - - if not dry_run: - run_command(cmd) + spec_config_str = "vanilla" + if method == "ngram": + spec_config_str = NGRAM_FMT.format( + min=spec_config["prompt_lookup_min"], + max=spec_config["prompt_lookup_max"], + k=spec_config["num_speculative_tokens"], + ) + elif method == "eagle": + spec_config_str = 
EAGLE_FMT.format( + k=spec_config["num_speculative_tokens"] + ) + + # dataset specific config + if "philschmid/mt-bench" in dataset_path: + bench_config_str = "mt_bench" + + num_prompts = num_batches * bench_concurrency + bench_vllm_config = f"--dataset-name {dataset_name} --dataset-path {dataset_path} --num-prompts {num_prompts}" + + print( + f"Number of prompts in {dataset_path}: {num_prompts}" + ) + + # create dir if not exists + # TODO: make the path shared with generate_config.py + # result_dir = f"{result_dir}/tp-{tp}_temp-{temp}_top_p-{top_p}_top_k-{top_k}/{bench_dataset}/{method}/" # noqa E501 + final_result_dir = f"{result_dir}/{method}/" # noqa E501 + if not os.path.exists(final_result_dir): + os.makedirs(final_result_dir) + + cmd = f'''time vllm bench serve --port {port} --save-result --save-detailed \ + --model {model_dir} \ + --backend openai-chat \ + --endpoint /v1/chat/completions \ + {bench_vllm_config} \ + --max-concurrency {bench_concurrency} \ + --temperature={temp} \ + --top-p={top_p} \ + --top-k={top_k} \ + --result-dir "{final_result_dir}" \ + --result-filename "{spec_config_str}_{bench_config_str}_bs-{bench_concurrency}_{extra_log_arg}.txt"''' + + print(cmd) + num_exp_run += 1 + + if not dry_run: + run_command(cmd) # server teardown: kill server and any gpu processes kill_server(port, server_process) @@ -206,8 +186,9 @@ def run_benchmarks(dry_run, model_dir, draft_dir, method, ) parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--draft-dir", type=str, default=None) - # parser.add_argument("--method-list", nargs='*', type=str, default=["vanilla", "eagle"]) parser.add_argument("--method", type=str, default="vanilla") + parser.add_argument("--prompt-lookup-max", type=int, default=5) + parser.add_argument("--prompt-lookup-min", type=int, default=2) parser.add_argument( "--num-speculative-tokens-list", nargs="*", type=int, default=[1, 3, 5] ) @@ -220,6 +201,12 @@ def run_benchmarks(dry_run, model_dir, draft_dir, method, help="Max vllm server batch size (max concurrency)", ) parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--temp", type=float, default=0) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--top-k", type=int, default=-1) + parser.add_argument("--num-batches", type=int, default=20, help="Number of batches to run for each benchmark.") + parser.add_argument("--dataset-name", type=str, default="hf") + parser.add_argument("--dataset-path", type=str, default="philschmid/mt-bench") parser.add_argument("--result-dir", type=str, default="./log/dynamic_sd") parser.add_argument("--extra-log-arg", type=str, default="") args = parser.parse_args() @@ -229,9 +216,17 @@ def run_benchmarks(dry_run, model_dir, draft_dir, method, model_dir = args.model_dir, draft_dir = args.draft_dir, method = args.method, + prompt_lookup_max = args.prompt_lookup_max, + prompt_lookup_min = args.prompt_lookup_min, num_speculative_tokens_list = args.num_speculative_tokens_list, batch_size_list = args.batch_size_list, max_vllm_batch_size = args.max_vllm_batch_size, tp = args.tp, + temp = args.temp, + top_p = args.top_p, + top_k = args.top_k, + num_batches = args.num_batches, + dataset_name = args.dataset_name, + dataset_path = args.dataset_path, result_dir = args.result_dir, extra_log_arg = args.extra_log_arg) \ No newline at end of file diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5acf224fa107..1238fe2c1a83 100644 --- a/vllm/v1/spec_decode/eagle.py +++ 
b/vllm/v1/spec_decode/eagle.py @@ -223,7 +223,7 @@ def _set_positions(self, num_tokens: int, positions: torch.Tensor): def propose( self, - optimal_num_speculative_tokens: Optional[int], + optimal_num_speculative_tokens: int | None, # [num_tokens] target_token_ids: torch.Tensor, # [num_tokens] or [3, num_tokens] when M-RoPE is enabled diff --git a/vllm/v1/spec_decode/offline.py b/vllm/v1/spec_decode/offline.py new file mode 100644 index 000000000000..9e9d30ed94f3 --- /dev/null +++ b/vllm/v1/spec_decode/offline.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.benchmarks.datasets import add_dataset_parser, get_samples +from vllm.inputs import TokensPrompt +from vllm.v1.metrics.reader import Counter, Vector + +try: + from vllm.utils.argparse_utils import FlexibleArgumentParser +except ImportError: + from argparse import ArgumentParser as FlexibleArgumentParser + + +QUESTION = "What is the content of each image?" +IMAGE_URLS = [ + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg", +] + + +def get_custom_mm_prompts(num_prompts): + prompts = [] + for url in IMAGE_URLS: + prompts.append( + [ + {"type": "image_url", "image_url": {"url": url}}, + {"type": "text", "text": QUESTION}, + ] + ) + if num_prompts > len(IMAGE_URLS): + prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1) + + return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]] + + +def parse_args(): + parser = FlexibleArgumentParser() + add_dataset_parser(parser) + parser.add_argument("--test", action="store_true") + parser.add_argument( + "--method", + type=str, + default="eagle", + choices=["ngram", "eagle", "eagle3", "mtp"], + ) + parser.add_argument("--num-spec-tokens", type=int, default=2) + parser.add_argument("--prompt-lookup-max", type=int, default=5) + parser.add_argument("--prompt-lookup-min", type=int, default=2) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--enforce-eager", action="store_true") + parser.add_argument("--enable-chunked-prefill", action="store_true") + parser.add_argument("--max-model-len", type=int, default=16384) + parser.add_argument("--temp", type=float, default=0) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--top-k", type=int, default=-1) + parser.add_argument("--print-output", action="store_true") + parser.add_argument("--output-len", type=int, 
default=256) + parser.add_argument("--model-dir", type=str, default=None) + parser.add_argument("--eagle-dir", type=str, default=None) + parser.add_argument("--custom-mm-prompts", action="store_true") + return parser.parse_args() + + +def main(args): + args.endpoint_type = "openai-chat" + + model_dir = args.model_dir + if args.model_dir is None: + if args.custom_mm_prompts: + raise ValueError( + "custom_mm_prompts requires mm based models" + "default llama3.1-8b-instruct is not mm based" + "please specify model_dir to give a mm based model" + ) + model_dir = "meta-llama/Llama-3.1-8B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_dir) + args.custom_skip_chat_template = True + + if not args.custom_mm_prompts: + prompts = get_samples(args, tokenizer) + # add_special_tokens is False to avoid adding bos twice + # when using chat templates + prompt_ids = [ + tokenizer.encode(prompt.prompt, add_special_tokens=False) + for prompt in prompts + ] + else: + prompts = get_custom_mm_prompts(args.num_prompts) + + if args.method == "eagle" or args.method == "eagle3": + eagle_dir = args.eagle_dir + if args.method == "eagle" and eagle_dir is None: + eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + + elif args.method == "eagle3" and eagle_dir is None: + eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + speculative_config = { + "method": args.method, + "model": eagle_dir, + "num_speculative_tokens": args.num_spec_tokens, + } + elif args.method == "ngram": + speculative_config = { + "method": "ngram", + "num_speculative_tokens": args.num_spec_tokens, + "prompt_lookup_max": args.prompt_lookup_max, + "prompt_lookup_min": args.prompt_lookup_min, + } + elif args.method == "mtp": + speculative_config = { + "method": "mtp", + "num_speculative_tokens": args.num_spec_tokens, + } + else: + raise ValueError(f"unknown method: {args.method}") + + llm = LLM( + model=model_dir, + trust_remote_code=True, + tensor_parallel_size=args.tp, + enable_chunked_prefill=args.enable_chunked_prefill, + enforce_eager=args.enforce_eager, + gpu_memory_utilization=0.9, + speculative_config=speculative_config, + disable_log_stats=False, + max_model_len=args.max_model_len, + limit_mm_per_prompt={"image": 5}, + disable_chunked_mm_input=True, + ) + + sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) + if not args.custom_mm_prompts: + outputs = llm.generate( + [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], + sampling_params=sampling_params, + ) + else: + outputs = llm.chat(prompts, sampling_params=sampling_params) + + # print the generated text + if args.print_output: + for output in outputs: + print("-" * 50) + print(f"prompt: {output.prompt}") + print(f"generated text: {output.outputs[0].text}") + print("-" * 50) + + metrics = llm.get_metrics() + + total_num_output_tokens = sum( + len(output.outputs[0].token_ids) for output in outputs + ) + num_drafts = 0 + num_draft_tokens = 0 + num_accepted_tokens = 0 + acceptance_counts = [0] * args.num_spec_tokens + for metric in metrics: + if metric.name == "vllm:spec_decode_num_drafts": + assert isinstance(metric, Counter) + num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_draft_tokens": + assert isinstance(metric, Counter) + num_draft_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens": + assert isinstance(metric, Counter) + num_accepted_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + assert isinstance(metric, Vector) + for pos in 
range(len(metric.values)): + acceptance_counts[pos] += metric.values[pos] + + print("-" * 50) + print(f"total_num_output_tokens: {total_num_output_tokens}") + print(f"num_drafts: {num_drafts}") + print(f"num_draft_tokens: {num_draft_tokens}") + print(f"num_accepted_tokens: {num_accepted_tokens}") + acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 + print(f"mean acceptance length: {acceptance_length:.2f}") + print("-" * 50) + + # print acceptance at each token position + acceptance_rate_per_pos = [] + for i in range(len(acceptance_counts)): + acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 + print(f"acceptance at token {i}: {acceptance_rate:.2f}") + acceptance_rate_per_pos.append(acceptance_rate) + + return acceptance_length, acceptance_rate_per_pos + + +def entrypoint(): + args = parse_args() + acceptance_length, acceptance_rate_per_pos = main(args) + + if args.test: + # takes ~30s to run on 1xH100 + assert args.method in ["eagle", "eagle3"] + assert args.tp == 1 + assert args.num_spec_tokens == 3 + assert args.dataset_name == "hf" + assert args.dataset_path == "philschmid/mt-bench" + assert args.num_prompts == 80 + assert args.temp == 0 + assert args.top_p == 1.0 + assert args.top_k == -1 + assert args.enable_chunked_prefill + + # check acceptance length is within 2% of expected value + rtol = 0.02 + expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811 + + assert ( + acceptance_length <= (1 + rtol) * expected_acceptance_length + and acceptance_length >= (1 - rtol) * expected_acceptance_length + ), ( + f"acceptance_length {acceptance_length} is not " + f"within {rtol * 100}% of {expected_acceptance_length}" + ) + + print( + f"Test passed! Expected AL: " + f"{expected_acceptance_length}, got {acceptance_length}" + ) + + +if __name__ == "__main__": + entrypoint() \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7774a8e3612a..060f4cc8b16e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -145,7 +145,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler -from vllm.v1.spec_decode.dynamic import DynamicSpeculativeDecodingManager +from vllm.v1.spec_decode.dynamic.manager import DynamicSpeculativeDecodingManager from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata From 618b5fdc7448351f16d2623188108d17c556b3c6 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Thu, 15 Jan 2026 02:24:02 +0000 Subject: [PATCH 06/39] load dynamic sd config Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/config/speculative.py | 15 +++++++++++++-- vllm/v1/worker/gpu_model_runner.py | 4 ++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index ca047712ccff..a144f88ce59d 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -157,8 +157,8 @@ class SpeculativeConfig: """The parallel configuration for the target model.""" # dynamic speculative decoding control - """Configuration for dynamic speculative decoding, if provided.""" - dynamic_config: DynamicSpeculativeConfig | None = None + """Path to config file for dynamic speculative 
decoding, if provided.""" + dynamic_config_path: str | None = None # params generated in the post-init stage draft_model_config: SkipValidation[ModelConfig] = None # type: ignore @@ -503,6 +503,17 @@ def __post_init__(self): self.target_parallel_config, self.draft_tensor_parallel_size ) ) + + # load DynamicSpeculativeConfig: maybe use get_hf_file_to_dict() later + if self.dynamic_config_path is not None: + import json + with open(self.dynamic_config_path) as f: + data = json.load(f) + + self.dynamic_config = DynamicSpeculativeConfig.model_validate(data) + else: + self.dynamic_config = None + return self def _validate_suffix_decoding(self): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 060f4cc8b16e..08f63df2068a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -467,6 +467,10 @@ def __init__( else: self.dynamic_sd_manager = None + # REMOVE + if self.dynamic_sd_manager: + print(f"_optimal_num_speculative_tokens: {self.dynamic_sd_manager._optimal_num_speculative_tokens}") + # Request states. self.requests: dict[str, CachedRequestState] = {} # NOTE(rob): num_prompt_logprobs only includes reqs From c540a75fb53d80d7d6db3e16abba52e3a773bdfd Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Thu, 15 Jan 2026 02:25:00 +0000 Subject: [PATCH 07/39] remove offline bkp Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- examples/offline_inference/spec_decode_bkp.py | 236 ------------------ 1 file changed, 236 deletions(-) delete mode 100644 examples/offline_inference/spec_decode_bkp.py diff --git a/examples/offline_inference/spec_decode_bkp.py b/examples/offline_inference/spec_decode_bkp.py deleted file mode 100644 index 8f3b8670e2cb..000000000000 --- a/examples/offline_inference/spec_decode_bkp.py +++ /dev/null @@ -1,236 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from transformers import AutoTokenizer - -from vllm import LLM, SamplingParams -from vllm.benchmarks.datasets import add_dataset_parser, get_samples -from vllm.inputs import TokensPrompt -from vllm.v1.metrics.reader import Counter, Vector - -try: - from vllm.utils.argparse_utils import FlexibleArgumentParser -except ImportError: - from argparse import ArgumentParser as FlexibleArgumentParser - - -QUESTION = "What is the content of each image?" 
-IMAGE_URLS = [ - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg", -] - - -def get_custom_mm_prompts(num_prompts): - prompts = [] - for url in IMAGE_URLS: - prompts.append( - [ - {"type": "image_url", "image_url": {"url": url}}, - {"type": "text", "text": QUESTION}, - ] - ) - if num_prompts > len(IMAGE_URLS): - prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1) - - return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]] - - -def parse_args(): - parser = FlexibleArgumentParser() - add_dataset_parser(parser) - parser.add_argument("--test", action="store_true") - parser.add_argument( - "--method", - type=str, - default="eagle", - choices=["ngram", "eagle", "eagle3", "mtp"], - ) - parser.add_argument("--num-spec-tokens", type=int, default=2) - parser.add_argument("--prompt-lookup-max", type=int, default=5) - parser.add_argument("--prompt-lookup-min", type=int, default=2) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--enforce-eager", action="store_true") - parser.add_argument("--enable-chunked-prefill", action="store_true") - parser.add_argument("--max-model-len", type=int, default=16384) - parser.add_argument("--temp", type=float, default=0) - parser.add_argument("--top-p", type=float, default=1.0) - parser.add_argument("--top-k", type=int, default=-1) - parser.add_argument("--print-output", action="store_true") - parser.add_argument("--output-len", type=int, default=256) - parser.add_argument("--model-dir", type=str, default=None) - parser.add_argument("--eagle-dir", type=str, default=None) - parser.add_argument("--custom-mm-prompts", action="store_true") - return parser.parse_args() - - -def main(args): - args.endpoint_type = "openai-chat" - - model_dir = args.model_dir - if args.model_dir is None: - if args.custom_mm_prompts: - raise ValueError( - "custom_mm_prompts requires mm based models" - "default llama3.1-8b-instruct is not mm based" - "please specify model_dir to give a mm based model" - ) - model_dir = "meta-llama/Llama-3.1-8B-Instruct" - tokenizer = AutoTokenizer.from_pretrained(model_dir) - args.custom_skip_chat_template = True - - if not args.custom_mm_prompts: - prompts = get_samples(args, tokenizer) - # add_special_tokens is False to avoid adding bos twice - # when using chat templates - prompt_ids = [ - tokenizer.encode(prompt.prompt, add_special_tokens=False) - for prompt in prompts - ] - else: - prompts = get_custom_mm_prompts(args.num_prompts) - - if args.method == "eagle" or args.method == "eagle3": - eagle_dir = 
args.eagle_dir - if args.method == "eagle" and eagle_dir is None: - eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" - - elif args.method == "eagle3" and eagle_dir is None: - eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" - speculative_config = { - "method": args.method, - "model": eagle_dir, - "num_speculative_tokens": args.num_spec_tokens, - } - elif args.method == "ngram": - speculative_config = { - "method": "ngram", - "num_speculative_tokens": args.num_spec_tokens, - "prompt_lookup_max": args.prompt_lookup_max, - "prompt_lookup_min": args.prompt_lookup_min, - } - elif args.method == "mtp": - speculative_config = { - "method": "mtp", - "num_speculative_tokens": args.num_spec_tokens, - } - else: - raise ValueError(f"unknown method: {args.method}") - - llm = LLM( - model=model_dir, - trust_remote_code=True, - tensor_parallel_size=args.tp, - enable_chunked_prefill=args.enable_chunked_prefill, - enforce_eager=args.enforce_eager, - gpu_memory_utilization=0.9, - speculative_config=speculative_config, - disable_log_stats=False, - max_model_len=args.max_model_len, - limit_mm_per_prompt={"image": 5}, - disable_chunked_mm_input=True, - ) - - sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) - if not args.custom_mm_prompts: - outputs = llm.generate( - [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], - sampling_params=sampling_params, - ) - else: - outputs = llm.chat(prompts, sampling_params=sampling_params) - - # print the generated text - if args.print_output: - for output in outputs: - print("-" * 50) - print(f"prompt: {output.prompt}") - print(f"generated text: {output.outputs[0].text}") - print("-" * 50) - - metrics = llm.get_metrics() - - total_num_output_tokens = sum( - len(output.outputs[0].token_ids) for output in outputs - ) - num_drafts = 0 - num_draft_tokens = 0 - num_accepted_tokens = 0 - acceptance_counts = [0] * args.num_spec_tokens - for metric in metrics: - if metric.name == "vllm:spec_decode_num_drafts": - assert isinstance(metric, Counter) - num_drafts += metric.value - elif metric.name == "vllm:spec_decode_num_draft_tokens": - assert isinstance(metric, Counter) - num_draft_tokens += metric.value - elif metric.name == "vllm:spec_decode_num_accepted_tokens": - assert isinstance(metric, Counter) - num_accepted_tokens += metric.value - elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": - assert isinstance(metric, Vector) - for pos in range(len(metric.values)): - acceptance_counts[pos] += metric.values[pos] - - print("-" * 50) - print(f"total_num_output_tokens: {total_num_output_tokens}") - print(f"num_drafts: {num_drafts}") - print(f"num_draft_tokens: {num_draft_tokens}") - print(f"num_accepted_tokens: {num_accepted_tokens}") - acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 - print(f"mean acceptance length: {acceptance_length:.2f}") - print("-" * 50) - - # print acceptance at each token position - acceptance_rate_per_pos = [] - for i in range(len(acceptance_counts)): - acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 - print(f"acceptance at token {i}: {acceptance_rate:.2f}") - acceptance_rate_per_pos.append(acceptance_rate) - - return acceptance_length, acceptance_rate_per_pos - - -if __name__ == "__main__": - args = parse_args() - acceptance_length, acceptance_rate_per_pos = main(args) - - if args.test: - # takes ~30s to run on 1xH100 - assert args.method in ["eagle", "eagle3"] - assert args.tp == 1 - assert args.num_spec_tokens == 3 - assert 
args.dataset_name == "hf" - assert args.dataset_path == "philschmid/mt-bench" - assert args.num_prompts == 80 - assert args.temp == 0 - assert args.top_p == 1.0 - assert args.top_k == -1 - assert args.enable_chunked_prefill - - # check acceptance length is within 2% of expected value - rtol = 0.02 - expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811 - - assert ( - acceptance_length <= (1 + rtol) * expected_acceptance_length - and acceptance_length >= (1 - rtol) * expected_acceptance_length - ), ( - f"acceptance_length {acceptance_length} is not " - f"within {rtol * 100}% of {expected_acceptance_length}" - ) - - print( - f"Test passed! Expected AL: " - f"{expected_acceptance_length}, got {acceptance_length}" - ) From dfb2b3125911f08c1352a362d4e53d55c4fa4ef8 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Thu, 15 Jan 2026 02:30:19 +0000 Subject: [PATCH 08/39] remove Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/dynamic/manager.py | 73 ++----------------- .../spec_decode/dynamic/profiling_server.py | 15 ---- 2 files changed, 8 insertions(+), 80 deletions(-) diff --git a/vllm/v1/spec_decode/dynamic/manager.py b/vllm/v1/spec_decode/dynamic/manager.py index a576421a7d43..54292a8f635e 100644 --- a/vllm/v1/spec_decode/dynamic/manager.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -4,74 +4,17 @@ import json from vllm.utils.argparse_utils import FlexibleArgumentParser -# _DYNAMIC_STATS = { -# "max_num_speculative_tokens": 7, -# "acceptance_rate_per_pos": [0.68, 0.39, 0.20, 0.10, 0.06, 0.03, 0.02], # E1 -# # "acceptance_rate_per_pos": [0.76, 0.54, 0.39, 0.28, 0.21, 0.15, 0.12], # E3 -# # "acceptance_rate_per_pos": [0.17, 0.15, 0.12, 0.01, 0.01, 0.01, 0.01], # E1 - low -# "batch_stats": { -# 1: { 0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29, }, -# 4: { 0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29, }, -# 16: { 0: 7.3, 1: 8.39, 3: 9.95, 4: 10.8, 5: 11.59, 7: 13.11, }, -# 32: { 0: 7.64, 1: 8.97, 3: 10.78, 4: 11.79, 5: 12.81, 7: 14.86, }, -# 64: { 0: 8.53, 1: 10.44, 3: 13.16, 4: 15.7, 5: 17.54, 7: 120.57, }, -# } -# } -_DYNAMIC_STATS = DynamicSpeculativeConfig( +_DYNAMIC_STATS_TEST = DynamicSpeculativeConfig( max_num_speculative_tokens=7, acceptance_rate_per_pos=[0.68, 0.39, 0.20, 0.10, 0.06, 0.03, 0.02], # E1 - # acceptance_rate_per_pos=[0.76, 0.54, 0.39, 0.28, 0.21, 0.15, 0.12], # E3 - # acceptance_rate_per_pos=[0.17, 0.15, 0.12, 0.01, 0.01, 0.01, 0.01], # E1 - low batch_stats={ - 1: { - 0: 6.87, - 1: 7.97, - 3: 9.41, - 4: 9.91, - 5: 10.8, - 7: 12.29, - }, - 4: { - 0: 6.87, - 1: 7.97, - 3: 9.41, - 4: 9.91, - 5: 10.8, - 7: 12.29, - }, - 16: { - 0: 7.3, - 1: 8.39, - 3: 9.95, - 4: 10.8, - 5: 11.59, - 7: 13.11, - }, - 32: { - 0: 7.64, - 1: 8.97, - 3: 10.78, - 4: 11.79, - 5: 12.81, - 7: 14.86, - }, - 64: { - 0: 8.53, - 1: 10.44, - 3: 13.16, - 4: 15.7, - 5: 17.54, - 7: 120.57, - }, - 128: { - 0: 8.53, - 1: 15.44, - 3: 25.16, - 4: 30.7, - 5: 37.54, - 7: 220.57, - }, # fake + 1: {0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29,}, + 4: {0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29,}, + 16: {0: 7.3, 1: 8.39, 3: 9.95, 4: 10.8, 5: 11.59, 7: 13.11,}, + 32: {0: 7.64, 1: 8.97, 3: 10.78, 4: 11.79, 5: 12.81, 7: 14.86,}, + 64: {0: 8.53, 1: 10.44, 3: 13.16, 4: 15.7, 5: 17.54, 7: 120.57,}, + 128: {0: 8.53, 1: 15.44, 3: 25.16, 4: 30.7, 5: 37.54, 7: 220.57,}, # fake }, ) @@ -246,7 +189,7 @@ def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> 
int: ) else: dynamic_sd = DynamicSpeculativeDecodingManager( - dynamic_config=_DYNAMIC_STATS, + dynamic_config=_DYNAMIC_STATS_TEST, vllm_max_batch_size=MAX_TEST_BS, vllm_num_speculative_tokens=7, ) diff --git a/vllm/v1/spec_decode/dynamic/profiling_server.py b/vllm/v1/spec_decode/dynamic/profiling_server.py index 3d82da653fbc..12641d52c16c 100644 --- a/vllm/v1/spec_decode/dynamic/profiling_server.py +++ b/vllm/v1/spec_decode/dynamic/profiling_server.py @@ -137,28 +137,13 @@ def start_server( return None -# def kill_server(port: int, server_process: subprocess.Popen | None): -# if server_process: -# server_process.kill() -# kill_gpu_processes(port) - - def kill_server(port, server_process): - # REMOVE - # print(f"Killing server on port {port}...") - if server_process: os.killpg(os.getpgid(server_process.pid), signal.SIGKILL) - # REMOVE - # print(f"Killed server process with PID: {server_process.pid if server_process else 'N/A'}") - wait_for_gpu_memory_to_clear() # Clean vLLM config config_path = os.path.expanduser("~/.config/vllm") if os.path.exists(config_path): subprocess.run(["rm", "-rf", config_path]) - - # REMOVE - # print(f"Killed server on port {port} and cleaned up GPU processes.") From bb65365d92205c7bed9376911768c71ae11ce275 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Thu, 15 Jan 2026 23:37:39 +0000 Subject: [PATCH 09/39] add runtime AL to goodput after warmup Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/core/sched/output.py | 4 ++ vllm/v1/core/sched/scheduler.py | 12 +++++ vllm/v1/spec_decode/dynamic/manager.py | 67 ++++++++++++++++++++++---- vllm/v1/spec_decode/metrics.py | 7 +++ vllm/v1/worker/gpu_model_runner.py | 14 ++---- 5 files changed, 83 insertions(+), 21 deletions(-) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index b69fa87ebddc..fd129e417df3 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -8,6 +8,7 @@ from typing_extensions import deprecated from vllm._bc_linter import bc_linter_include +from vllm.v1.spec_decode.metrics import SpecDecodingStats if TYPE_CHECKING: import numpy as np @@ -207,6 +208,9 @@ class SchedulerOutput: # EC Cache Connector metadata ec_connector_metadata: ECConnectorMetadata | None = None + # Spec Decoding stats for all requests. + spec_decoding_stats_all: SpecDecodingStats | None = None + @classmethod def make_empty(cls) -> "SchedulerOutput": return cls( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 0111fd6e7198..6ae88ec06488 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -203,6 +203,7 @@ def __init__( if speculative_config.use_eagle(): self.use_eagle = True self.num_lookahead_tokens = self.num_spec_tokens + self.spec_decoding_stats_all = SpecDecodingStats.new(self.num_spec_tokens) # Create the KV cache manager. self.kv_cache_manager = KVCacheManager( @@ -739,8 +740,11 @@ def schedule(self) -> SchedulerOutput: # the previous and the current steps. finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), + spec_decoding_stats_all=self.spec_decoding_stats_all, ) + # REMOVE + # NOTE(Kuntai): this function is designed for multiple purposes: # 1. Plan the KV cache store # 2. 
Wrap up all the KV cache load / save ops into an opaque object @@ -1530,6 +1534,14 @@ def make_spec_decoding_stats( num_draft_tokens: int, num_accepted_tokens: int, ) -> SpecDecodingStats | None: + # Save this so its accessible by scheduler and can + # be sent to engine for Dynamic SD. + if self.spec_decoding_stats_all is not None: + self.spec_decoding_stats_all.observe_draft( + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted_tokens, + ) + if not self.log_stats: return None if spec_decoding_stats is None: diff --git a/vllm/v1/spec_decode/dynamic/manager.py b/vllm/v1/spec_decode/dynamic/manager.py index 54292a8f635e..c71baa137b1c 100644 --- a/vllm/v1/spec_decode/dynamic/manager.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -31,11 +31,15 @@ def __init__( dynamic_config: DynamicSpeculativeConfig | None, vllm_max_batch_size: int, vllm_num_speculative_tokens: int, + warmup_steps: int = 10, # TODO: make this configurable ): self.dynamic_config = dynamic_config self.vllm_max_batch_size = vllm_max_batch_size + self.vllm_num_speculative_tokens = vllm_num_speculative_tokens self.batch_stats = self.dynamic_config.batch_stats self.available_batch_sizes = sorted(self.dynamic_config.batch_stats.keys()) + self.steps = 0 + self.warmup_steps = warmup_steps # Sanity check assert ( @@ -76,6 +80,42 @@ def __init__( self.update_optimal_num_speculative_tokens() + def step(self, spec_decoding_stats_all, batch_size: int) -> int: + self.steps += 1 + if self.should_update(): + acceptance_rate_per_pos = self.compute_acceptance_rate_per_pos(spec_decoding_stats_all) + self.update_acceptance_rate_per_pos( + acceptance_rate_per_pos + ) + + optimal_num_speculative_tokens = ( + self.get_optimal_num_speculative_tokens( + batch_size + ) + ) + + return optimal_num_speculative_tokens + + def compute_acceptance_rate_per_pos(self, spec_decoding_stats_all) -> list[float]: + acceptance_rate_per_pos = [] + for i in range(self.vllm_num_speculative_tokens): + if spec_decoding_stats_all.num_draft_tokens_per_pos[i] == 0: + acceptance_rate = 0.0 + else: + acceptance_rate = ( + spec_decoding_stats_all.num_accepted_tokens_per_pos[i] + / spec_decoding_stats_all.num_draft_tokens_per_pos[i] + ) + + acceptance_rate_per_pos.append(acceptance_rate) + + return acceptance_rate_per_pos + + def should_update(self) -> bool: + # making this a separate function for easier overriding or extension + return self.steps > self.warmup_steps + + def get_optimal_num_speculative_tokens(self, batch_size: int) -> int: assert batch_size > 0, "batch_size must be > 0" assert batch_size <= self.vllm_max_batch_size, ( @@ -89,14 +129,21 @@ def update_optimal_num_speculative_tokens(self): for bs in range(1, self.vllm_max_batch_size + 1) } + def update_acceptance_rate_per_pos( + self, acceptance_rate_per_pos: list[float] + ): + self.dynamic_config.acceptance_rate_per_pos = acceptance_rate_per_pos + self.update_optimal_num_speculative_tokens() + + def _get_batch_stats(self, batch_size: int) -> dict: # import pdb; pdb.set_trace() if batch_size not in self.batch_stats: # find the nearest batch size smaller and bigger than the given batch size # and return the weighted avg of their stats - print( - f"Finding batch stats for batch_size: {batch_size} in self.available_batch_sizes: {self.available_batch_sizes}" - ) + # print( + # f"Finding batch stats for batch_size: {batch_size} in self.available_batch_sizes: {self.available_batch_sizes}" + # ) smaller_bs = [bs for bs in self.available_batch_sizes if bs < batch_size] smaller_bs = ( @@ -108,9 +155,9 
@@ def _get_batch_stats(self, batch_size: int) -> dict: ) # REMOVE - print( - f"smaller_bs: {smaller_bs}, larger_bs: {larger_bs}, batch_size: {batch_size}" - ) + # print( + # f"smaller_bs: {smaller_bs}, larger_bs: {larger_bs}, batch_size: {batch_size}" + # ) smaller_bs_stat = self.batch_stats[smaller_bs] larger_bs_stat = self.batch_stats[larger_bs] @@ -118,7 +165,7 @@ def _get_batch_stats(self, batch_size: int) -> dict: ratio = (batch_size - smaller_bs) / (larger_bs - smaller_bs) # REMOVE - print(f"ratio: {ratio}") + # print(f"ratio: {ratio}") avg_stat: dict[int, float] = {} for k in smaller_bs_stat: @@ -158,9 +205,9 @@ def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: chosen_num_drafts = num_drafts # REMOVE - print( - f"num_drafts: {num_drafts}, al: {curr_al}, itl: {curr_itl}, goodput: {curr_goodput}" - ) + # print( + # f"num_drafts: {num_drafts}, al: {curr_al}, itl: {curr_itl}, goodput: {curr_goodput}" + # ) return chosen_num_drafts diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 6c16bc686d16..3dfc62359a08 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -27,12 +27,14 @@ class SpecDecodingStats: num_draft_tokens: int = 0 num_accepted_tokens: int = 0 num_accepted_tokens_per_pos: list[int] = field(default_factory=list) + num_draft_tokens_per_pos: list[int] = field(default_factory=list) @classmethod def new(cls, num_spec_tokens: int) -> "SpecDecodingStats": return cls( num_spec_tokens=num_spec_tokens, num_accepted_tokens_per_pos=[0] * num_spec_tokens, + num_draft_tokens_per_pos=[0] * num_spec_tokens, ) def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int): @@ -42,6 +44,11 @@ def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int): assert num_accepted_tokens <= self.num_spec_tokens for i in range(num_accepted_tokens): self.num_accepted_tokens_per_pos[i] += 1 + for i in range(num_draft_tokens): + self.num_draft_tokens_per_pos[i] += 1 + + # REMOVE + # print(f"self.num_drafts: {self.num_drafts}, num_draft_tokens: {num_draft_tokens}, num_accepted_tokens: {num_accepted_tokens}") class SpecDecodingLogging: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 08f63df2068a..8ce5bea7015f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3605,17 +3605,9 @@ def propose_draft_token_ids( ) -> list[list[int]] | torch.Tensor: optimal_num_speculative_tokens = None if self.dynamic_sd_manager: - batch_size = self.input_batch.num_reqs - optimal_num_speculative_tokens = ( - self.dynamic_sd_manager.get_optimal_num_speculative_tokens( - self.input_batch.num_reqs - ) - ) - - # REMOVE - print( - f"Batch size: {batch_size}, " - f"Optimal num speculative tokens: {optimal_num_speculative_tokens}" + optimal_num_speculative_tokens = self.dynamic_sd_manager.step( + scheduler_output.spec_decoding_stats_all, + self.input_batch.num_reqs, ) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens From 5fcf59edb7a11d13a056c9fa4bb98e8680859384 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Sat, 7 Feb 2026 23:27:59 -0500 Subject: [PATCH 10/39] Update vllm/v1/spec_decode/dynamic/manager.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/dynamic/manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git 
a/vllm/v1/spec_decode/dynamic/manager.py b/vllm/v1/spec_decode/dynamic/manager.py index c71baa137b1c..9d1ccf6b0ac1 100644 --- a/vllm/v1/spec_decode/dynamic/manager.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -162,6 +162,9 @@ def _get_batch_stats(self, batch_size: int) -> dict: smaller_bs_stat = self.batch_stats[smaller_bs] larger_bs_stat = self.batch_stats[larger_bs] + if larger_bs == smaller_bs: + return self.batch_stats[smaller_bs] + ratio = (batch_size - smaller_bs) / (larger_bs - smaller_bs) # REMOVE From 409eb69f40431ebd83726a8114bc974487d79206 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Sun, 8 Feb 2026 04:59:49 +0000 Subject: [PATCH 11/39] revert offline decoder to save loc diff Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- examples/offline_inference/spec_decode.py | 234 +++++++++++++++++++++- 1 file changed, 232 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index df64f75e47b0..6b7783545fb0 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -1,4 +1,234 @@ -from vllm.v1.spec_decode.offline import entrypoint as spec_decode_main +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.benchmarks.datasets import add_dataset_parser, get_samples +from vllm.inputs import TokensPrompt +from vllm.v1.metrics.reader import Counter, Vector + +try: + from vllm.utils.argparse_utils import FlexibleArgumentParser +except ImportError: + from argparse import ArgumentParser as FlexibleArgumentParser + + +QUESTION = "What is the content of each image?" 
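# Illustrative usage only (not part of this patch): the flags below simply
# mirror the configuration asserted under --test at the bottom of this file.
#   python examples/offline_inference/spec_decode.py --test --method eagle \
#     --num-spec-tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench \
#     --num-prompts 80 --temp 0 --enable-chunked-prefill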
+IMAGE_URLS = [ + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg", + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg", +] + + +def get_custom_mm_prompts(num_prompts): + prompts = [] + for url in IMAGE_URLS: + prompts.append( + [ + {"type": "image_url", "image_url": {"url": url}}, + {"type": "text", "text": QUESTION}, + ] + ) + if num_prompts > len(IMAGE_URLS): + prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1) + + return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]] + + +def parse_args(): + parser = FlexibleArgumentParser() + add_dataset_parser(parser) + parser.add_argument("--test", action="store_true") + parser.add_argument( + "--method", + type=str, + default="eagle", + choices=["ngram", "eagle", "eagle3", "mtp"], + ) + parser.add_argument("--num-spec-tokens", type=int, default=2) + parser.add_argument("--prompt-lookup-max", type=int, default=5) + parser.add_argument("--prompt-lookup-min", type=int, default=2) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--enforce-eager", action="store_true") + parser.add_argument("--enable-chunked-prefill", action="store_true") + parser.add_argument("--max-model-len", type=int, default=16384) + parser.add_argument("--temp", type=float, default=0) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--top-k", type=int, default=-1) + parser.add_argument("--print-output", action="store_true") + parser.add_argument("--output-len", type=int, default=256) + parser.add_argument("--model-dir", type=str, default=None) + parser.add_argument("--eagle-dir", type=str, default=None) + parser.add_argument("--custom-mm-prompts", action="store_true") + return parser.parse_args() + + +def main(args): + args.endpoint_type = "openai-chat" + + model_dir = args.model_dir + if args.model_dir is None: + if args.custom_mm_prompts: + raise ValueError( + "custom_mm_prompts requires mm based models" + "default llama3.1-8b-instruct is not mm based" + "please specify model_dir to give a mm based model" + ) + model_dir = "meta-llama/Llama-3.1-8B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_dir) + args.custom_skip_chat_template = True + + if not args.custom_mm_prompts: + prompts = get_samples(args, tokenizer) + # add_special_tokens is False to avoid adding bos twice + # when using chat templates + prompt_ids = [ + tokenizer.encode(prompt.prompt, add_special_tokens=False) + for prompt in prompts + ] + else: + prompts = get_custom_mm_prompts(args.num_prompts) + + if args.method == "eagle" or args.method == "eagle3": + eagle_dir = 
args.eagle_dir + if args.method == "eagle" and eagle_dir is None: + eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + + elif args.method == "eagle3" and eagle_dir is None: + eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + speculative_config = { + "method": args.method, + "model": eagle_dir, + "num_speculative_tokens": args.num_spec_tokens, + } + elif args.method == "ngram": + speculative_config = { + "method": "ngram", + "num_speculative_tokens": args.num_spec_tokens, + "prompt_lookup_max": args.prompt_lookup_max, + "prompt_lookup_min": args.prompt_lookup_min, + } + elif args.method == "mtp": + speculative_config = { + "method": "mtp", + "num_speculative_tokens": args.num_spec_tokens, + } + else: + raise ValueError(f"unknown method: {args.method}") + + llm = LLM( + model=model_dir, + trust_remote_code=True, + tensor_parallel_size=args.tp, + enable_chunked_prefill=args.enable_chunked_prefill, + enforce_eager=args.enforce_eager, + gpu_memory_utilization=0.9, + speculative_config=speculative_config, + disable_log_stats=False, + max_model_len=args.max_model_len, + limit_mm_per_prompt={"image": 5}, + disable_chunked_mm_input=True, + ) + + sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) + if not args.custom_mm_prompts: + outputs = llm.generate( + [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], + sampling_params=sampling_params, + ) + else: + outputs = llm.chat(prompts, sampling_params=sampling_params) + + # print the generated text + if args.print_output: + for output in outputs: + print("-" * 50) + print(f"prompt: {output.prompt}") + print(f"generated text: {output.outputs[0].text}") + print("-" * 50) + + metrics = llm.get_metrics() + + total_num_output_tokens = sum( + len(output.outputs[0].token_ids) for output in outputs + ) + num_drafts = 0 + num_draft_tokens = 0 + num_accepted_tokens = 0 + acceptance_counts = [0] * args.num_spec_tokens + for metric in metrics: + if metric.name == "vllm:spec_decode_num_drafts": + assert isinstance(metric, Counter) + num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_draft_tokens": + assert isinstance(metric, Counter) + num_draft_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens": + assert isinstance(metric, Counter) + num_accepted_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + assert isinstance(metric, Vector) + for pos in range(len(metric.values)): + acceptance_counts[pos] += metric.values[pos] + + print("-" * 50) + print(f"total_num_output_tokens: {total_num_output_tokens}") + print(f"num_drafts: {num_drafts}") + print(f"num_draft_tokens: {num_draft_tokens}") + print(f"num_accepted_tokens: {num_accepted_tokens}") + acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 + print(f"mean acceptance length: {acceptance_length:.2f}") + print("-" * 50) + + # print acceptance at each token position + for i in range(len(acceptance_counts)): + acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 + print(f"acceptance at token {i}: {acceptance_rate:.2f}") + + return acceptance_length + if __name__ == "__main__": - spec_decode_main() \ No newline at end of file + args = parse_args() + acceptance_length = main(args) + + if args.test: + # takes ~30s to run on 1xH100 + assert args.method in ["eagle", "eagle3"] + assert args.tp == 1 + assert args.num_spec_tokens == 3 + assert args.dataset_name == "hf" + assert args.dataset_path == "philschmid/mt-bench" + assert 
args.num_prompts == 80 + assert args.temp == 0 + assert args.top_p == 1.0 + assert args.top_k == -1 + assert args.enable_chunked_prefill + + # check acceptance length is within 2% of expected value + rtol = 0.02 + expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811 + + assert ( + acceptance_length <= (1 + rtol) * expected_acceptance_length + and acceptance_length >= (1 - rtol) * expected_acceptance_length + ), ( + f"acceptance_length {acceptance_length} is not " + f"within {rtol * 100}% of {expected_acceptance_length}" + ) + + print( + f"Test passed! Expected AL: " + f"{expected_acceptance_length}, got {acceptance_length}" + ) \ No newline at end of file From d7a149fde0025bad2bc029e85abf5ab3e1f76792 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Sun, 8 Feb 2026 05:33:36 +0000 Subject: [PATCH 12/39] refactor Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- .../v1/spec_decode/dynamic/generate_config.py | 248 ++++++++++++++---- .../dynamic/process_benchmark_results.py | 5 +- 2 files changed, 205 insertions(+), 48 deletions(-) diff --git a/vllm/v1/spec_decode/dynamic/generate_config.py b/vllm/v1/spec_decode/dynamic/generate_config.py index d27bffacaf22..3518af6a7655 100644 --- a/vllm/v1/spec_decode/dynamic/generate_config.py +++ b/vllm/v1/spec_decode/dynamic/generate_config.py @@ -1,26 +1,192 @@ import json +import pprint import time - +from pathlib import Path + +from vllm.benchmarks.sweep.param_sweep import ParameterSweep +from vllm.benchmarks.sweep.serve import SweepServeArgs, run_main from vllm.config.speculative import DynamicSpeculativeConfig -from vllm.v1.spec_decode.dynamic.process_benchmark_results import parse_itl from vllm.v1.spec_decode.offline import main as spec_decode_main -from vllm.v1.spec_decode.dynamic.process_benchmark_results import parse_itl -from vllm.v1.spec_decode.dynamic.profiling_client import run_benchmarks from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.benchmarks.datasets import add_dataset_parser + +def build_serve_params(method, draft_dir, tp, + num_speculative_tokens_list, + prompt_lookup_max, prompt_lookup_min): + """Build serve parameter sweep for vanilla + speculative decode configs. + + Each entry becomes a separate server configuration in the sweep. + The sweep framework starts/stops the server for each serve config. 
+ """ + records = [] + + # Vanilla config (no speculative decoding) + records.append({"_benchmark_name": "vanilla"}) + + # Speculative decoding configs with varying num_speculative_tokens + if method == "ngram": + for k in num_speculative_tokens_list: + records.append({ + "_benchmark_name": f"ngram-k-{k}", + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": k, + "prompt_lookup_max": prompt_lookup_max, + "prompt_lookup_min": prompt_lookup_min, + }, + }) + elif method in ("eagle", "eagle3"): + for k in num_speculative_tokens_list: + records.append({ + "_benchmark_name": f"{method}-k-{k}", + "speculative_config": { + "method": method, + "model": draft_dir, + "num_speculative_tokens": k, + "draft_tensor_parallel_size": tp, + }, + }) + elif method == "mtp": + for k in num_speculative_tokens_list: + records.append({ + "_benchmark_name": f"mtp-k-{k}", + "speculative_config": { + "method": "mtp", + "num_speculative_tokens": k, + }, + }) + + return ParameterSweep.from_records(records) + + +def build_bench_params(batch_size_list, num_batches): + """Build benchmark parameter sweep for different concurrencies. + + Each entry varies max_concurrency (batch size) and num_prompts. + """ + records = [] + for bs in batch_size_list: + num_prompts = num_batches * bs + records.append({ + "_benchmark_name": f"bs-{bs}", + "max_concurrency": bs, + "num_prompts": num_prompts, + }) + return ParameterSweep.from_records(records) + + +def parse_itl_from_dataframe(result_df): + """Parse ITL from sweep result DataFrame into batch_stats format. + + Returns: + batch_stats: dict of {batch_size: {num_drafts: median_itl_ms}} + where num_drafts=0 corresponds to vanilla (no speculation). + """ + batch_stats = {} + for _, row in result_df.iterrows(): + bs = int(row["max_concurrency"]) + + # Determine k (num speculative tokens) from speculative_config. + # For vanilla rows, speculative_config is NaN (missing from serve + # params); for spec decode rows, it's a dict or JSON string. + spec_config = row.get("speculative_config") + if isinstance(spec_config, dict): + k = int(spec_config["num_speculative_tokens"]) + elif isinstance(spec_config, str): + k = int(json.loads(spec_config)["num_speculative_tokens"]) + else: + k = 0 # vanilla (NaN or None) + + if bs not in batch_stats: + batch_stats[bs] = {} + batch_stats[bs][k] = row["median_itl_ms"] + + return batch_stats + + +def run_profiling_sweep(args): + """Run profiling benchmarks using vllm bench sweep serve. 
+ + This replaces the custom profiling_client/profiling_server by leveraging + the existing vllm bench sweep serve utility which handles: + - Server lifecycle management (start, wait-for-ready, stop) + - Cartesian product of serve_params x bench_params + - Result saving and aggregation + """ + # Base serve command (static params shared across all serve configs) + serve_cmd = [ + "vllm", "serve", args.model_dir, + "--disable-log-requests", + "--gpu-memory-utilization", "0.95", + "--max-num-seqs", str(args.max_vllm_batch_size), + "--tensor-parallel-size", str(args.tp), + "--enable-chunked-prefill", + "--no-enable-prefix-caching", + ] + + # Base bench command (static params shared across all bench configs) + bench_cmd = [ + "vllm", "bench", "serve", + "--model", args.model_dir, + "--backend", "openai-chat", + "--endpoint", "/v1/chat/completions", + "--dataset-name", args.dataset_name, + "--dataset-path", args.dataset_path, + f"--temperature={args.temp}", + f"--top-p={args.top_p}", + f"--top-k={args.top_k}", + ] + + # Build parameter sweeps + serve_params = build_serve_params( + method=args.method, + draft_dir=args.draft_dir, + tp=args.tp, + num_speculative_tokens_list=args.num_speculative_tokens_list, + prompt_lookup_max=args.prompt_lookup_max, + prompt_lookup_min=args.prompt_lookup_min, + ) + + bench_params = build_bench_params( + batch_size_list=args.batch_size_list, + num_batches=args.num_batches, + ) + + # Run the sweep + sweep_args = SweepServeArgs( + serve_cmd=serve_cmd, + bench_cmd=bench_cmd, + after_bench_cmd=[], + show_stdout=True, + serve_params=serve_params, + bench_params=bench_params, + output_dir=Path(args.result_dir), + num_runs=1, + dry_run=False, + resume=None, + link_vars=[], + ) + + result_df = run_main(sweep_args) + return result_df + + def main(): parser = FlexibleArgumentParser() add_dataset_parser(parser) parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--draft-dir", type=str, default=None) - parser.add_argument("--method", type=str, default="eagle", choices=["ngram", "eagle", "eagle3", "mtp"]) + parser.add_argument("--method", type=str, default="eagle", + choices=["ngram", "eagle", "eagle3", "mtp"]) parser.add_argument( - "--num-speculative-tokens-list", nargs="*", type=int, default=[1, 3, 5] + "--num-speculative-tokens-list", nargs="*", type=int, + default=[1, 3, 5] ) parser.add_argument( - "--batch-size-list", nargs="*", type=int, default=[1, 4, 16, 64, 256] + "--batch-size-list", nargs="*", type=int, + default=[1, 4, 16, 64, 256] ) parser.add_argument( "--max-vllm-batch-size", @@ -29,7 +195,6 @@ def main(): ) parser.add_argument("--tp", type=int, default=1) parser.add_argument("--result-dir", type=str, default="./log/dynamic_sd") - parser.add_argument("--extra-log-arg", type=str, default="") parser.add_argument("--prompt-lookup-max", type=int, default=5) parser.add_argument("--prompt-lookup-min", type=int, default=2) parser.add_argument("--max-model-len", type=int, default=16384) @@ -37,9 +202,9 @@ def main(): parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=-1) parser.add_argument("--output-len", type=int, default=256) - parser.add_argument("--num-batches", type=int, default=20, help="Number of batches to run for each benchmark.") + parser.add_argument("--num-batches", type=int, default=20, + help="Number of batches to run for each benchmark.") parser.add_argument("--custom-mm-prompts", action="store_true") - args = parser.parse_args() args.enable_chunked_prefill = True @@ -47,48 
+212,36 @@ def main(): args.print_output = False args.num_spec_tokens = max(args.num_speculative_tokens_list) args.eagle_dir = args.draft_dir - args.result_dir = f"{args.result_dir}/tp-{args.tp}_temp-{args.temp}_top_p-{args.top_p}_top_k-{args.top_k}/{args.dataset_path}/" - - # print the args in pretty format - import pprint + args.result_dir = (f"{args.result_dir}/tp-{args.tp}_temp-{args.temp}" + f"_top_p-{args.top_p}_top_k-{args.top_k}" + f"/{args.dataset_path}/") + + assert args.max_vllm_batch_size == max(args.batch_size_list), ( + "max_vllm_batch_size must be equal to max of batch_size_list" + ) + pprint.pprint(vars(args)) start = time.time() - # Step 1: get acceptance_rate_per_pos acceptance_length, acceptance_rate_per_pos = spec_decode_main(args) print(f"Acceptance length: {acceptance_length}") print(f"Acceptance rate per position: {acceptance_rate_per_pos}") - print(f"✅ Step 1: obtained acceptance rate per position.") - - # Step 2: generate benchmark data for vanilla and specified method - for method in ["vanilla", args.method]: - run_benchmarks( - dry_run = False, - model_dir = args.model_dir, - draft_dir = args.draft_dir, - method = method, - prompt_lookup_max = args.prompt_lookup_max, - prompt_lookup_min = args.prompt_lookup_min, - num_speculative_tokens_list = args.num_speculative_tokens_list, - batch_size_list = args.batch_size_list, - max_vllm_batch_size = args.max_vllm_batch_size, - tp = args.tp, - temp = args.temp, - top_p = args.top_p, - top_k = args.top_k, - num_batches = args.num_batches, - dataset_name = args.dataset_name, - dataset_path = args.dataset_path, - result_dir = args.result_dir, - extra_log_arg = args.extra_log_arg - ) + print("✅ Step 1: obtained acceptance rate per position.") + + # Step 2: generate benchmark data using vllm bench sweep serve + # This runs the Cartesian product of: + # serve_params: [vanilla, {method}-k-1, {method}-k-3, {method}-k-5] + # bench_params: [bs-1, bs-4, bs-16, bs-64, bs-256] + # The sweep framework handles server start/stop for each serve config. 
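    # Illustrative sketch only (placeholder ITL values, not measurements):
    # with the defaults above this is 4 serve configs (vanilla, eagle-k-1,
    # eagle-k-3, eagle-k-5) x 5 bench configs (bs-1, bs-4, bs-16, bs-64,
    # bs-256) = 20 benchmark runs, which Step 3 reduces to the
    # {batch_size: {num_drafts: median_itl_ms}} shape, e.g.:
    #   {1: {0: 7.0, 1: 8.0, 3: 9.5, 5: 11.0},
    #    4: {0: 7.1, 1: 8.2, 3: 9.7, 5: 11.3}}
    # where the inner key 0 is the vanilla (no-speculation) baseline.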
+ result_df = run_profiling_sweep(args) print(f"✅ Step 2: benchmark data generated for vanilla and {args.method}.") # Step 3: parse batch_stats from benchmark data - batch_stats = parse_itl(method=args.method, benchmark_path_parent=args.result_dir) - print(f"✅ Step 3: parsed batch statistics from benchmark data.") - + batch_stats = parse_itl_from_dataframe(result_df) + print("✅ Step 3: parsed batch statistics from benchmark data.") + print(json.dumps(batch_stats, indent=4)) + # Step 4: Save DynamicSpeculativeConfig to a json file dynamic_config = DynamicSpeculativeConfig( is_online=False, @@ -97,10 +250,11 @@ def main(): batch_stats=batch_stats, ) - with open(f"{args.result_dir}/dynamic_speculative_config.json", "w") as f: + config_path = f"{args.result_dir}/dynamic_speculative_config.json" + with open(config_path, "w") as f: json.dump(dynamic_config.model_dump(), f, indent=4) - - print(f"✅ Step 4: config saved to {args.result_dir}/dynamic_speculative_config.json") + + print(f"✅ Step 4: config saved to {config_path}") end = time.time() print(f"Total time taken: {end - start:.2f} seconds") @@ -136,11 +290,11 @@ def main(): --max-vllm-batch-size 256 \ --batch-size-list 1 256 \ --num-speculative-tokens-list 1 5 \ - --num-batches 20 \ + --num-batches 5 \ --dataset-name hf \ --dataset-path 'philschmid/mt-bench' \ --no-oversample \ - --result-dir './log/dynamic_sd_test' + --result-dir './log/dynamic_sd_test_short' """ if __name__ == "__main__": main() \ No newline at end of file diff --git a/vllm/v1/spec_decode/dynamic/process_benchmark_results.py b/vllm/v1/spec_decode/dynamic/process_benchmark_results.py index ac5548176713..ff56e3240853 100644 --- a/vllm/v1/spec_decode/dynamic/process_benchmark_results.py +++ b/vllm/v1/spec_decode/dynamic/process_benchmark_results.py @@ -5,7 +5,10 @@ import os import re -from vllm.v1.spec_decode.dynamic.profiling_client import EAGLE_FMT, NGRAM_FMT + +# Format strings for speculative config naming in benchmark result files. +NGRAM_FMT = "min-{min}-max-{max}-k-{k}" +EAGLE_FMT = "k-{k}" def reverse_fmt(fmt_str): From 44aed5ea6bdf475bc56e878b23eff8dde3728775 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Sun, 8 Feb 2026 05:41:34 +0000 Subject: [PATCH 13/39] conflict Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/core/sched/scheduler.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 611a185bcbfc..9bea6036ea15 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1871,7 +1871,6 @@ def make_spec_decoding_stats( num_invalid_spec_tokens: dict[str, int] | None, request_id: str, ) -> SpecDecodingStats | None: -<<<<<<< HEAD # Save this so its accessible by scheduler and can # be sent to engine for Dynamic SD. 
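        # Illustrative worked example (restating the observe_draft() logic
        # added to metrics.py in this series, not new behavior): one draft
        # with num_draft_tokens=3 and num_accepted_tokens=2 increments
        # num_draft_tokens_per_pos[0..2] and num_accepted_tokens_per_pos[0..1];
        # the dynamic SD manager later derives acceptance_rate_per_pos[i] as
        # num_accepted_tokens_per_pos[i] / num_draft_tokens_per_pos[i].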
if self.spec_decoding_stats_all is not None: @@ -1880,10 +1879,7 @@ def make_spec_decoding_stats( num_accepted_tokens=num_accepted_tokens, ) - if not self.log_stats: -======= if not self.log_stats or not num_draft_tokens: ->>>>>>> ab10d798555ee3611f82e71cbe573086fb92a4ed return None if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) From 7ed335368d3006bb3b8f15db65c781c26c9d3cbd Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Sun, 8 Feb 2026 05:57:53 +0000 Subject: [PATCH 14/39] refactor Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/config/speculative.py | 3 +- .../dynamic/process_benchmark_results.py | 96 -------- .../spec_decode/dynamic/profiling_client.py | 232 ------------------ .../spec_decode/dynamic/profiling_server.py | 149 ----------- 4 files changed, 1 insertion(+), 479 deletions(-) delete mode 100644 vllm/v1/spec_decode/dynamic/process_benchmark_results.py delete mode 100644 vllm/v1/spec_decode/dynamic/profiling_client.py delete mode 100644 vllm/v1/spec_decode/dynamic/profiling_server.py diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 82797cc1630e..212b0461f642 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -4,7 +4,7 @@ import ast from typing import TYPE_CHECKING, Any, Literal, get_args -from pydantic import Field, SkipValidation, model_validator +from pydantic import Field, SkipValidation, model_validator, BaseModel from typing_extensions import Self from vllm.config.model import ModelConfig @@ -53,7 +53,6 @@ ] -# @dataclass class DynamicSpeculativeConfig(BaseModel): # """A mapping from batch size to optimal number of drafts to use for that # batch size. This is used to dynamically adjust the number of drafts used diff --git a/vllm/v1/spec_decode/dynamic/process_benchmark_results.py b/vllm/v1/spec_decode/dynamic/process_benchmark_results.py deleted file mode 100644 index ff56e3240853..000000000000 --- a/vllm/v1/spec_decode/dynamic/process_benchmark_results.py +++ /dev/null @@ -1,96 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse -import json -import os -import re - - -# Format strings for speculative config naming in benchmark result files. 
-NGRAM_FMT = "min-{min}-max-{max}-k-{k}" -EAGLE_FMT = "k-{k}" - - -def reverse_fmt(fmt_str): - # e.g., convert 'min-{min}-max-{max}-k-{k}' -> 'min-{}-max-{}-k-{}' - FMT = re.sub(r"\{[^}]+\}", "{}", fmt_str) - # e.g., convert 'min-{}-max-{}-k-{}' -> 'min-(.+)-max-(.+)-k-(.+)' - FMT = FMT.replace("{}", "(.+)") - return FMT - - -NGRAM_FMT_REVERSE = reverse_fmt(NGRAM_FMT) -EAGLE_FMT_REVERSE = reverse_fmt(EAGLE_FMT) - - -def parse_itl(method, benchmark_path_parent): - """ - DynamicSpeculativeConfig.batch_stats: dict - The structure is as follows: - { - batch_size: { - num_drafts: itl (i.e., inter token latency in ms) - } - } - """ - batch_stats = {} - - for method in ["vanilla", method]: - # find the names of all log files in this folder - benchmark_path = os.path.join(benchmark_path_parent, method) - all_log_files = [ - f - for f in os.listdir(benchmark_path) - if os.path.isfile(os.path.join(benchmark_path, f)) - and f.endswith(".txt") - ] - - # parse the log files to get the config params - for log_file in all_log_files: - # find bs - bs = re.search(r"_bs-(\d+)", log_file).group(1) - - # find sd params - spec_config_str = log_file.split("_")[0] - if method == "vanilla": - k = 0 - elif method == "ngram": - min, max, k = re.match(NGRAM_FMT_REVERSE, spec_config_str).groups() - elif method == "eagle": - k = re.match(EAGLE_FMT_REVERSE, spec_config_str).groups()[0] - - # read the log file to get the itl - with open(os.path.join(benchmark_path, log_file)) as f: - data = json.load(f) - itl = data["median_itl_ms"] - - # add to batch_stats - if int(bs) not in batch_stats: - batch_stats[int(bs)] = {} - - batch_stats[int(bs)][int(k)] = itl - - print(json.dumps(batch_stats, indent=4)) - return batch_stats - - -""" -python3 vllm/v1/spec_decode/process_benchmark_results.py \ - --method eagle \ - --benchmark-path-parent 'log/dynamic_sd/tp-1_temp-0_top_p-1/philschmid/mt-bench/' -""" -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--method", type=str, default=None) - parser.add_argument( - "--benchmark-path-parent", - type=str, - default=None, - help="Root folder which has the log files", - ) - - args = parser.parse_args() - assert args.method in ["ngram", "eagle", "eagle3", "mtp"], "Invalid method specified." 
- - batch_stats = parse_itl(method=args.method, - benchmark_path_parent=args.benchmark_path_parent) diff --git a/vllm/v1/spec_decode/dynamic/profiling_client.py b/vllm/v1/spec_decode/dynamic/profiling_client.py deleted file mode 100644 index 294e4e72f7a9..000000000000 --- a/vllm/v1/spec_decode/dynamic/profiling_client.py +++ /dev/null @@ -1,232 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse -import os -import subprocess -from dataclasses import dataclass - -from vllm.v1.spec_decode.dynamic.profiling_server import ( - kill_server, - setup_server, - start_server, -) - - -NGRAM_FMT = "min-{min}-max-{max}-k-{k}" -EAGLE_FMT = "k-{k}" - - -def run_command(command): - try: - result = subprocess.run( - f"bash -c '{command}'", - shell=True, - check=True, - capture_output=True, - text=True, - ) - print("Output:") - print(result.stdout) - except subprocess.CalledProcessError as e: - print("Error:") - print(e.stderr) - - -def run_benchmarks(dry_run, - model_dir, - draft_dir, - method, - prompt_lookup_max, - prompt_lookup_min, - num_speculative_tokens_list, - batch_size_list, - max_vllm_batch_size, - tp, - temp, - top_p, - top_k, - num_batches, - dataset_name, - dataset_path, - result_dir, - extra_log_arg): - - assert method in ["vanilla", "ngram", "eagle", "eagle3"], ( - "invalid method specified" - ) - - assert max_vllm_batch_size == max(batch_size_list), ( - "max_vllm_batch_size must be equal to max of batch_size" - ) - - # setup server - setup_server() - - port = 9001 - - # ablation - num_exp_run = 0 - - # collate all spec configs to run for a given method - all_spec_config = [] - if method == "ngram": - for ngram_k in num_speculative_tokens_list: - all_spec_config.append( - { - "method": "ngram", - "num_speculative_tokens": ngram_k, - "prompt_lookup_max": prompt_lookup_max, - "prompt_lookup_min": prompt_lookup_min, - } - ) - elif method == "eagle": - for eagle_k in num_speculative_tokens_list: - all_spec_config.append( - { - "method": "eagle", - "model": draft_dir, - "num_speculative_tokens": eagle_k, - "draft_tensor_parallel_size": tp, - } - ) - else: - # vanilla case - all_spec_config.append(None) - - for spec_config in all_spec_config: - - # start server - server_process = start_server( - port=port, - target_model_dir=model_dir, - spec_config=spec_config, - tp=tp, - max_vllm_bs=max_vllm_batch_size, - dry_run=dry_run, - ) - - # start client - for bench_concurrency in batch_size_list: - spec_config_str = "vanilla" - if method == "ngram": - spec_config_str = NGRAM_FMT.format( - min=spec_config["prompt_lookup_min"], - max=spec_config["prompt_lookup_max"], - k=spec_config["num_speculative_tokens"], - ) - elif method == "eagle": - spec_config_str = EAGLE_FMT.format( - k=spec_config["num_speculative_tokens"] - ) - - # dataset specific config - if "philschmid/mt-bench" in dataset_path: - bench_config_str = "mt_bench" - - num_prompts = num_batches * bench_concurrency - bench_vllm_config = f"--dataset-name {dataset_name} --dataset-path {dataset_path} --num-prompts {num_prompts}" - - print( - f"Number of prompts in {dataset_path}: {num_prompts}" - ) - - # create dir if not exists - # TODO: make the path shared with generate_config.py - # result_dir = f"{result_dir}/tp-{tp}_temp-{temp}_top_p-{top_p}_top_k-{top_k}/{bench_dataset}/{method}/" # noqa E501 - final_result_dir = f"{result_dir}/{method}/" # noqa E501 - if not os.path.exists(final_result_dir): - os.makedirs(final_result_dir) - - cmd = f'''time vllm bench serve 
--port {port} --save-result --save-detailed \ - --model {model_dir} \ - --backend openai-chat \ - --endpoint /v1/chat/completions \ - {bench_vllm_config} \ - --max-concurrency {bench_concurrency} \ - --temperature={temp} \ - --top-p={top_p} \ - --top-k={top_k} \ - --result-dir "{final_result_dir}" \ - --result-filename "{spec_config_str}_{bench_config_str}_bs-{bench_concurrency}_{extra_log_arg}.txt"''' - - print(cmd) - num_exp_run += 1 - - if not dry_run: - run_command(cmd) - - # server teardown: kill server and any gpu processes - kill_server(port, server_process) - - print(f"Total number of experiments run: {num_exp_run}") - - -""" -# eagle -time python3 vllm/v1/spec_decode/online_profiling_client.py \ - --batch-size-list 1 4 16 64 256 \ - --num-speculative-tokens-list 1 3 5 \ - --max-vllm-batch-size 256 \ - --method eagle \ - --model-dir meta-llama/Llama-3.1-8B-Instruct \ - --draft-dir yuhuili/EAGLE-LLaMA3.1-Instruct-8B - -# vanilla -time python3 vllm/v1/spec_decode/online_profiling_client.py \ - --batch-size-list 1 4 16 64 256 \ - --max-vllm-batch-size 256 \ - --method vanilla -""" -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--dry-run", - action="store_true", - help="Run in dry run mode. If set, commands will be printed but not executed.", - ) - parser.add_argument("--model-dir", type=str, default=None) - parser.add_argument("--draft-dir", type=str, default=None) - parser.add_argument("--method", type=str, default="vanilla") - parser.add_argument("--prompt-lookup-max", type=int, default=5) - parser.add_argument("--prompt-lookup-min", type=int, default=2) - parser.add_argument( - "--num-speculative-tokens-list", nargs="*", type=int, default=[1, 3, 5] - ) - parser.add_argument( - "--batch-size-list", nargs="*", type=int, default=[1, 4, 16, 64, 256] - ) - parser.add_argument( - "--max-vllm-batch-size", - type=int, - help="Max vllm server batch size (max concurrency)", - ) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--temp", type=float, default=0) - parser.add_argument("--top-p", type=float, default=1.0) - parser.add_argument("--top-k", type=int, default=-1) - parser.add_argument("--num-batches", type=int, default=20, help="Number of batches to run for each benchmark.") - parser.add_argument("--dataset-name", type=str, default="hf") - parser.add_argument("--dataset-path", type=str, default="philschmid/mt-bench") - parser.add_argument("--result-dir", type=str, default="./log/dynamic_sd") - parser.add_argument("--extra-log-arg", type=str, default="") - args = parser.parse_args() - - run_benchmarks( - dry_run = args.dry_run, - model_dir = args.model_dir, - draft_dir = args.draft_dir, - method = args.method, - prompt_lookup_max = args.prompt_lookup_max, - prompt_lookup_min = args.prompt_lookup_min, - num_speculative_tokens_list = args.num_speculative_tokens_list, - batch_size_list = args.batch_size_list, - max_vllm_batch_size = args.max_vllm_batch_size, - tp = args.tp, - temp = args.temp, - top_p = args.top_p, - top_k = args.top_k, - num_batches = args.num_batches, - dataset_name = args.dataset_name, - dataset_path = args.dataset_path, - result_dir = args.result_dir, - extra_log_arg = args.extra_log_arg) \ No newline at end of file diff --git a/vllm/v1/spec_decode/dynamic/profiling_server.py b/vllm/v1/spec_decode/dynamic/profiling_server.py deleted file mode 100644 index 12641d52c16c..000000000000 --- a/vllm/v1/spec_decode/dynamic/profiling_server.py +++ /dev/null @@ -1,149 +0,0 @@ -# SPDX-License-Identifier: 
Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import json -import os -import shutil -import signal -import subprocess -import time - -""" -Utility functions to manage the vLLM server for online profiling. -Main functions are setup_server(), start_server(), and kill_server(). -""" - - -def wait_for_server(port: int) -> bool: - timeout = 1200 # 20 mins - start_time = time.time() - while time.time() - start_time < timeout: - try: - subprocess.run( - ["curl", "-X", "POST", f"localhost:{port}/v1/completions"], check=True - ) - return True - except subprocess.CalledProcessError: - time.sleep(10) # wait for 10 seconds before retrying - return False - - -def kill_gpu_processes(port: int): - subprocess.run(["ps", "-aux"]) - subprocess.run([f"lsof -t -i:{port} | xargs -r kill -9"], shell=True) - - # Use ps to list all Python processes and grep to exclude the specific one - command = ["ps", "aux"] - ps_output = subprocess.check_output(command, text=True) - - # Do not kill this process - filename = os.path.basename(__file__) - - # REMOVE - print(f"filename to exclude: {filename}") - - pids_to_kill = [] - for line in ps_output.split("\n"): - if "python3" in line and filename not in line: - pid = line.split()[1] - pids_to_kill.append(pid) - - # Kill other processes - for pid in pids_to_kill: - subprocess.run(["kill", "-9", pid]) - - wait_for_gpu_memory_to_clear() - - # subprocess.run(["rm", "-rf", "~/.config/vllm"]) - - -def wait_for_gpu_memory_to_clear(): - # Wait until all GPUs have memory usage < 1000 MB - if shutil.which("nvidia-smi"): - while True: - # Get GPU memory usage for all GPUs - memory_usage = subprocess.check_output( - [ - "nvidia-smi", - "--query-gpu=memory.used", - "--format=csv,noheader,nounits", - ], - text=True, - ) - # Split the output into individual GPU memory usage values - gpu_memory_usage = [int(x) for x in memory_usage.strip().split("\n")] - # Check if any GPU has memory usage >= 1000 MB - if all(usage < 1000 for usage in gpu_memory_usage): - break - time.sleep(1) - elif shutil.which("amd-smi"): - while True: - memory_usage = subprocess.check_output( - ["amd-smi", "metric", "-g", "0"], text=True - ) - used_vram = int(memory_usage.split("USED_VRAM")[1].split()[0]) - if used_vram < 1000: - break - time.sleep(1) - - -def setup_server(): - # install dependencies - dependencies = ["lsof", "curl", "pgrep"] - for dep in dependencies: - if not shutil.which(dep): - subprocess.run(["apt-get", "update"]) - subprocess.run(["apt-get", "install", "-y", dep]) - - -def start_server( - port: int, - target_model_dir: str, - spec_config: dict | None, - tp: int, - max_vllm_bs: int, - dry_run: bool = False, -) -> subprocess.Popen | None: - # NOTE: no Prompt Caching, but enabled chunked prefill - server_command = f"""VLLM_USE_V1=1 vllm serve {target_model_dir} \ - --disable-log-requests --port {port} \ - --gpu_memory_utilization 0.95 \ - --max_num_seqs {max_vllm_bs} \ - --tensor_parallel_size {tp} \ - --enable-chunked-prefill \ - --no-enable-prefix-caching """ - - if spec_config: - speculative_config_json_serialized = json.dumps(spec_config).replace('"', '\\"') - server_command += ( - f'--speculative_config "{speculative_config_json_serialized}" ' - ) - - print(f"Server command: {server_command}") - - # start vllm server - if not dry_run: - server_process = subprocess.Popen( - server_command, shell=True, preexec_fn=os.setsid - ) - - if wait_for_server(port): - print("vllm server is up and running.") - else: - print("vllm failed to start within the timeout 
period.") - server_process.kill() - - return server_process - else: - return None - - -def kill_server(port, server_process): - if server_process: - os.killpg(os.getpgid(server_process.pid), signal.SIGKILL) - - wait_for_gpu_memory_to_clear() - - # Clean vLLM config - config_path = os.path.expanduser("~/.config/vllm") - if os.path.exists(config_path): - subprocess.run(["rm", "-rf", config_path]) From 116e76b5f78ff176f834811fbe93d5b1f33c78c7 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Sun, 8 Feb 2026 06:11:40 +0000 Subject: [PATCH 15/39] add timeout Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/dynamic/generate_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/spec_decode/dynamic/generate_config.py b/vllm/v1/spec_decode/dynamic/generate_config.py index 3518af6a7655..36953e5396af 100644 --- a/vllm/v1/spec_decode/dynamic/generate_config.py +++ b/vllm/v1/spec_decode/dynamic/generate_config.py @@ -166,6 +166,7 @@ def run_profiling_sweep(args): dry_run=False, resume=None, link_vars=[], + server_ready_timeout=600, ) result_df = run_main(sweep_args) From d70bd1d5f2acc5de9c3854783f0a843469ed972f Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Sun, 8 Feb 2026 06:29:42 +0000 Subject: [PATCH 16/39] fix Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/core/sched/scheduler.py | 1 + vllm/v1/spec_decode/dynamic/generate_config.py | 9 +-------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 9bea6036ea15..986c740617df 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -213,6 +213,7 @@ def __init__( speculative_config = vllm_config.speculative_config self.use_eagle = False self.num_spec_tokens = self.num_lookahead_tokens = 0 + self.spec_decoding_stats_all = None if speculative_config: self.num_spec_tokens = speculative_config.num_speculative_tokens self.spec_decoding_stats_all = SpecDecodingStats.new(self.num_spec_tokens) diff --git a/vllm/v1/spec_decode/dynamic/generate_config.py b/vllm/v1/spec_decode/dynamic/generate_config.py index 36953e5396af..b23ebba05098 100644 --- a/vllm/v1/spec_decode/dynamic/generate_config.py +++ b/vllm/v1/spec_decode/dynamic/generate_config.py @@ -106,14 +106,7 @@ def parse_itl_from_dataframe(result_df): def run_profiling_sweep(args): - """Run profiling benchmarks using vllm bench sweep serve. 
- - This replaces the custom profiling_client/profiling_server by leveraging - the existing vllm bench sweep serve utility which handles: - - Server lifecycle management (start, wait-for-ready, stop) - - Cartesian product of serve_params x bench_params - - Result saving and aggregation - """ + """Run profiling benchmarks using vllm bench sweep serve.""" # Base serve command (static params shared across all serve configs) serve_cmd = [ "vllm", "serve", args.model_dir, From 24304d54884eeab113018a5fcdfab12d16d6389f Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 9 Feb 2026 00:01:19 +0000 Subject: [PATCH 17/39] reduce loc in favor of #34105 Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/offline.py | 240 --------------------------------- 1 file changed, 240 deletions(-) delete mode 100644 vllm/v1/spec_decode/offline.py diff --git a/vllm/v1/spec_decode/offline.py b/vllm/v1/spec_decode/offline.py deleted file mode 100644 index 9e9d30ed94f3..000000000000 --- a/vllm/v1/spec_decode/offline.py +++ /dev/null @@ -1,240 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from transformers import AutoTokenizer - -from vllm import LLM, SamplingParams -from vllm.benchmarks.datasets import add_dataset_parser, get_samples -from vllm.inputs import TokensPrompt -from vllm.v1.metrics.reader import Counter, Vector - -try: - from vllm.utils.argparse_utils import FlexibleArgumentParser -except ImportError: - from argparse import ArgumentParser as FlexibleArgumentParser - - -QUESTION = "What is the content of each image?" -IMAGE_URLS = [ - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg", - "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg", -] - - -def get_custom_mm_prompts(num_prompts): - prompts = [] - for url in IMAGE_URLS: - prompts.append( - [ - {"type": "image_url", "image_url": {"url": url}}, - {"type": "text", "text": QUESTION}, - ] - ) - if num_prompts > len(IMAGE_URLS): - prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1) - - return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]] - - -def parse_args(): - parser = FlexibleArgumentParser() - add_dataset_parser(parser) - parser.add_argument("--test", action="store_true") - parser.add_argument( - "--method", - type=str, - default="eagle", - choices=["ngram", "eagle", "eagle3", "mtp"], - ) - parser.add_argument("--num-spec-tokens", type=int, default=2) - parser.add_argument("--prompt-lookup-max", type=int, 
default=5) - parser.add_argument("--prompt-lookup-min", type=int, default=2) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--enforce-eager", action="store_true") - parser.add_argument("--enable-chunked-prefill", action="store_true") - parser.add_argument("--max-model-len", type=int, default=16384) - parser.add_argument("--temp", type=float, default=0) - parser.add_argument("--top-p", type=float, default=1.0) - parser.add_argument("--top-k", type=int, default=-1) - parser.add_argument("--print-output", action="store_true") - parser.add_argument("--output-len", type=int, default=256) - parser.add_argument("--model-dir", type=str, default=None) - parser.add_argument("--eagle-dir", type=str, default=None) - parser.add_argument("--custom-mm-prompts", action="store_true") - return parser.parse_args() - - -def main(args): - args.endpoint_type = "openai-chat" - - model_dir = args.model_dir - if args.model_dir is None: - if args.custom_mm_prompts: - raise ValueError( - "custom_mm_prompts requires mm based models" - "default llama3.1-8b-instruct is not mm based" - "please specify model_dir to give a mm based model" - ) - model_dir = "meta-llama/Llama-3.1-8B-Instruct" - tokenizer = AutoTokenizer.from_pretrained(model_dir) - args.custom_skip_chat_template = True - - if not args.custom_mm_prompts: - prompts = get_samples(args, tokenizer) - # add_special_tokens is False to avoid adding bos twice - # when using chat templates - prompt_ids = [ - tokenizer.encode(prompt.prompt, add_special_tokens=False) - for prompt in prompts - ] - else: - prompts = get_custom_mm_prompts(args.num_prompts) - - if args.method == "eagle" or args.method == "eagle3": - eagle_dir = args.eagle_dir - if args.method == "eagle" and eagle_dir is None: - eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" - - elif args.method == "eagle3" and eagle_dir is None: - eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" - speculative_config = { - "method": args.method, - "model": eagle_dir, - "num_speculative_tokens": args.num_spec_tokens, - } - elif args.method == "ngram": - speculative_config = { - "method": "ngram", - "num_speculative_tokens": args.num_spec_tokens, - "prompt_lookup_max": args.prompt_lookup_max, - "prompt_lookup_min": args.prompt_lookup_min, - } - elif args.method == "mtp": - speculative_config = { - "method": "mtp", - "num_speculative_tokens": args.num_spec_tokens, - } - else: - raise ValueError(f"unknown method: {args.method}") - - llm = LLM( - model=model_dir, - trust_remote_code=True, - tensor_parallel_size=args.tp, - enable_chunked_prefill=args.enable_chunked_prefill, - enforce_eager=args.enforce_eager, - gpu_memory_utilization=0.9, - speculative_config=speculative_config, - disable_log_stats=False, - max_model_len=args.max_model_len, - limit_mm_per_prompt={"image": 5}, - disable_chunked_mm_input=True, - ) - - sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) - if not args.custom_mm_prompts: - outputs = llm.generate( - [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], - sampling_params=sampling_params, - ) - else: - outputs = llm.chat(prompts, sampling_params=sampling_params) - - # print the generated text - if args.print_output: - for output in outputs: - print("-" * 50) - print(f"prompt: {output.prompt}") - print(f"generated text: {output.outputs[0].text}") - print("-" * 50) - - metrics = llm.get_metrics() - - total_num_output_tokens = sum( - len(output.outputs[0].token_ids) for output in outputs - ) - num_drafts = 0 - num_draft_tokens = 0 - 
num_accepted_tokens = 0 - acceptance_counts = [0] * args.num_spec_tokens - for metric in metrics: - if metric.name == "vllm:spec_decode_num_drafts": - assert isinstance(metric, Counter) - num_drafts += metric.value - elif metric.name == "vllm:spec_decode_num_draft_tokens": - assert isinstance(metric, Counter) - num_draft_tokens += metric.value - elif metric.name == "vllm:spec_decode_num_accepted_tokens": - assert isinstance(metric, Counter) - num_accepted_tokens += metric.value - elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": - assert isinstance(metric, Vector) - for pos in range(len(metric.values)): - acceptance_counts[pos] += metric.values[pos] - - print("-" * 50) - print(f"total_num_output_tokens: {total_num_output_tokens}") - print(f"num_drafts: {num_drafts}") - print(f"num_draft_tokens: {num_draft_tokens}") - print(f"num_accepted_tokens: {num_accepted_tokens}") - acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 - print(f"mean acceptance length: {acceptance_length:.2f}") - print("-" * 50) - - # print acceptance at each token position - acceptance_rate_per_pos = [] - for i in range(len(acceptance_counts)): - acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 - print(f"acceptance at token {i}: {acceptance_rate:.2f}") - acceptance_rate_per_pos.append(acceptance_rate) - - return acceptance_length, acceptance_rate_per_pos - - -def entrypoint(): - args = parse_args() - acceptance_length, acceptance_rate_per_pos = main(args) - - if args.test: - # takes ~30s to run on 1xH100 - assert args.method in ["eagle", "eagle3"] - assert args.tp == 1 - assert args.num_spec_tokens == 3 - assert args.dataset_name == "hf" - assert args.dataset_path == "philschmid/mt-bench" - assert args.num_prompts == 80 - assert args.temp == 0 - assert args.top_p == 1.0 - assert args.top_k == -1 - assert args.enable_chunked_prefill - - # check acceptance length is within 2% of expected value - rtol = 0.02 - expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811 - - assert ( - acceptance_length <= (1 + rtol) * expected_acceptance_length - and acceptance_length >= (1 - rtol) * expected_acceptance_length - ), ( - f"acceptance_length {acceptance_length} is not " - f"within {rtol * 100}% of {expected_acceptance_length}" - ) - - print( - f"Test passed! 
Expected AL: " - f"{expected_acceptance_length}, got {acceptance_length}" - ) - - -if __name__ == "__main__": - entrypoint() \ No newline at end of file From 8fb86b1c9d2af8d7654903bf43ce7826ed54ef60 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 9 Feb 2026 00:09:04 +0000 Subject: [PATCH 18/39] remove test from dynamic manager main() Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/dynamic/manager.py | 96 +++----------------------- 1 file changed, 11 insertions(+), 85 deletions(-) diff --git a/vllm/v1/spec_decode/dynamic/manager.py b/vllm/v1/spec_decode/dynamic/manager.py index 9d1ccf6b0ac1..37a931ad9fcc 100644 --- a/vllm/v1/spec_decode/dynamic/manager.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -1,22 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.config.speculative import DynamicSpeculativeConfig import json -from vllm.utils.argparse_utils import FlexibleArgumentParser - -_DYNAMIC_STATS_TEST = DynamicSpeculativeConfig( - max_num_speculative_tokens=7, - acceptance_rate_per_pos=[0.68, 0.39, 0.20, 0.10, 0.06, 0.03, 0.02], # E1 - batch_stats={ - 1: {0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29,}, - 4: {0: 6.87, 1: 7.97, 3: 9.41, 4: 9.91, 5: 10.8, 7: 12.29,}, - 16: {0: 7.3, 1: 8.39, 3: 9.95, 4: 10.8, 5: 11.59, 7: 13.11,}, - 32: {0: 7.64, 1: 8.97, 3: 10.78, 4: 11.79, 5: 12.81, 7: 14.86,}, - 64: {0: 8.53, 1: 10.44, 3: 13.16, 4: 15.7, 5: 17.54, 7: 120.57,}, - 128: {0: 8.53, 1: 15.44, 3: 25.16, 4: 30.7, 5: 37.54, 7: 220.57,}, # fake - }, -) +from vllm.config.speculative import DynamicSpeculativeConfig +from vllm.utils.argparse_utils import FlexibleArgumentParser class DynamicSpeculativeDecodingManager: @@ -31,7 +18,7 @@ def __init__( dynamic_config: DynamicSpeculativeConfig | None, vllm_max_batch_size: int, vllm_num_speculative_tokens: int, - warmup_steps: int = 10, # TODO: make this configurable + warmup_steps: int = 10, # TODO: make this configurable ): self.dynamic_config = dynamic_config self.vllm_max_batch_size = vllm_max_batch_size @@ -47,7 +34,6 @@ def __init__( <= self.dynamic_config.max_num_speculative_tokens ), "vllm_num_speculative_tokens must be <= max_num_speculative_tokens" - # if self.dynamic_config.is_online: assert self.dynamic_config.max_num_speculative_tokens == len( self.dynamic_config.acceptance_rate_per_pos ), ( @@ -83,19 +69,17 @@ def __init__( def step(self, spec_decoding_stats_all, batch_size: int) -> int: self.steps += 1 if self.should_update(): - acceptance_rate_per_pos = self.compute_acceptance_rate_per_pos(spec_decoding_stats_all) - self.update_acceptance_rate_per_pos( - acceptance_rate_per_pos - ) - - optimal_num_speculative_tokens = ( - self.get_optimal_num_speculative_tokens( - batch_size + acceptance_rate_per_pos = self.compute_acceptance_rate_per_pos( + spec_decoding_stats_all ) + self.update_acceptance_rate_per_pos(acceptance_rate_per_pos) + + optimal_num_speculative_tokens = self.get_optimal_num_speculative_tokens( + batch_size ) return optimal_num_speculative_tokens - + def compute_acceptance_rate_per_pos(self, spec_decoding_stats_all) -> list[float]: acceptance_rate_per_pos = [] for i in range(self.vllm_num_speculative_tokens): @@ -115,7 +99,6 @@ def should_update(self) -> bool: # making this a separate function for easier overriding or extension return self.steps > self.warmup_steps - def get_optimal_num_speculative_tokens(self, batch_size: int) -> int: 
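# A minimal, self-contained illustration of the goodput rule the manager uses
# to pick K: goodput(K) = expected accepted tokens per step / ITL(K), and the
# K with the highest goodput wins. The function name and all numbers below are
# invented for illustration; the real code also interpolates ITL across the
# profiled batch sizes and draft counts.
def pick_optimal_k(acceptance_rate_per_pos: list[float],
                   itl_ms_per_k: dict[int, float]) -> int:
    best_k, best_goodput = 0, -1.0
    for k, itl_ms in sorted(itl_ms_per_k.items()):
        # Expected accepted length with K drafts: the bonus token plus the
        # acceptance rates of the first K draft positions.
        al = 1.0 + sum(acceptance_rate_per_pos[:k])
        goodput = al / itl_ms
        if goodput > best_goodput:
            best_k, best_goodput = k, goodput
    return best_k

# e.g. pick_optimal_k([0.68, 0.39, 0.20], {0: 6.9, 1: 8.0, 3: 9.4}) -> 3
# (goodput is 0.145, 0.210 and 0.241 tokens/ms respectively, so K=3 wins).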
assert batch_size > 0, "batch_size must be > 0" assert batch_size <= self.vllm_max_batch_size, ( @@ -129,21 +112,15 @@ def update_optimal_num_speculative_tokens(self): for bs in range(1, self.vllm_max_batch_size + 1) } - def update_acceptance_rate_per_pos( - self, acceptance_rate_per_pos: list[float] - ): + def update_acceptance_rate_per_pos(self, acceptance_rate_per_pos: list[float]): self.dynamic_config.acceptance_rate_per_pos = acceptance_rate_per_pos self.update_optimal_num_speculative_tokens() - def _get_batch_stats(self, batch_size: int) -> dict: # import pdb; pdb.set_trace() if batch_size not in self.batch_stats: # find the nearest batch size smaller and bigger than the given batch size # and return the weighted avg of their stats - # print( - # f"Finding batch stats for batch_size: {batch_size} in self.available_batch_sizes: {self.available_batch_sizes}" - # ) smaller_bs = [bs for bs in self.available_batch_sizes if bs < batch_size] smaller_bs = ( @@ -154,11 +131,6 @@ def _get_batch_stats(self, batch_size: int) -> dict: min(larger_bs) if len(larger_bs) else self.available_batch_sizes[-1] ) - # REMOVE - # print( - # f"smaller_bs: {smaller_bs}, larger_bs: {larger_bs}, batch_size: {batch_size}" - # ) - smaller_bs_stat = self.batch_stats[smaller_bs] larger_bs_stat = self.batch_stats[larger_bs] @@ -167,9 +139,6 @@ def _get_batch_stats(self, batch_size: int) -> dict: ratio = (batch_size - smaller_bs) / (larger_bs - smaller_bs) - # REMOVE - # print(f"ratio: {ratio}") - avg_stat: dict[int, float] = {} for k in smaller_bs_stat: avg_stat[k] = smaller_bs_stat[k] + ratio * ( @@ -187,9 +156,6 @@ def _get_itl(self, batch_stats, num_drafts: int) -> float: lower_num_draft = max(k for k in batch_stats.keys() if k < num_drafts) upper_num_draft = min(k for k in batch_stats.keys() if k > num_drafts) - # REMOVE - # print(f"lower_num_draft: {lower_num_draft}, upper_num_draft: {upper_num_draft}, num_drafts: {num_drafts}") - ratio = (num_drafts - lower_num_draft) / (upper_num_draft - lower_num_draft) lower_itl = batch_stats[lower_num_draft] upper_itl = batch_stats[upper_num_draft] @@ -207,44 +173,4 @@ def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: max_goodput = curr_goodput chosen_num_drafts = num_drafts - # REMOVE - # print( - # f"num_drafts: {num_drafts}, al: {curr_al}, itl: {curr_itl}, goodput: {curr_goodput}" - # ) - return chosen_num_drafts - - -# python3 vllm/v1/spec_decode/dynamic/manager.py --dynamic-config-path log/dynamic_sd_test_2/tp-1_temp-0.0_top_p-1.0_top_k--1/philschmid/mt-bench/dynamic_speculative_config.json -if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser.add_argument( - "--dynamic-config-path", - type=str, - default=None, - help="Path to the dynamic speculative decoding config json file.", - ) - args = parser.parse_args() - - MAX_TEST_BS = 128 - if args.dynamic_config_path: - with open(args.dynamic_config_path) as f: - data = json.load(f) - - dynamic_config = DynamicSpeculativeConfig.model_validate(data) - dynamic_sd = DynamicSpeculativeDecodingManager( - dynamic_config=dynamic_config, - vllm_max_batch_size=max(dynamic_config.batch_stats.keys()), - vllm_num_speculative_tokens=dynamic_config.max_num_speculative_tokens, - ) - else: - dynamic_sd = DynamicSpeculativeDecodingManager( - dynamic_config=_DYNAMIC_STATS_TEST, - vllm_max_batch_size=MAX_TEST_BS, - vllm_num_speculative_tokens=7, - ) - for i in range(1, dynamic_sd.vllm_max_batch_size + 1, 4): - print("\n====================================") - print( - f"bs: {i}, optimal num drafts: 
{dynamic_sd.get_optimal_num_speculative_tokens(i)}" - ) From 11d43a5323d80b2b7eb2c85919cc7fc694318d98 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 9 Feb 2026 00:21:55 +0000 Subject: [PATCH 19/39] remove comment and fix lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- examples/offline_inference/spec_decode.py | 2 +- vllm/config/speculative.py | 5 +- .../v1/spec_decode/dynamic/generate_config.py | 150 +++++++++++------- vllm/v1/spec_decode/dynamic/manager.py | 16 +- vllm/v1/spec_decode/eagle.py | 2 +- vllm/v1/spec_decode/metrics.py | 3 - vllm/v1/worker/gpu_model_runner.py | 6 +- 7 files changed, 105 insertions(+), 79 deletions(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 6256e4ac3f9c..d8c5ece4fa66 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -260,4 +260,4 @@ def main(args): print( f"Test passed! Expected AL: " f"{expected_acceptance_length}, got {acceptance_length}" - ) \ No newline at end of file + ) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 212b0461f642..92aae82f4b2e 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -4,7 +4,7 @@ import ast from typing import TYPE_CHECKING, Any, Literal, get_args -from pydantic import Field, SkipValidation, model_validator, BaseModel +from pydantic import BaseModel, Field, SkipValidation, model_validator from typing_extensions import Self from vllm.config.model import ModelConfig @@ -169,8 +169,8 @@ class SpeculativeConfig: """The parallel configuration for the target model.""" # dynamic speculative decoding control - """Path to config file for dynamic speculative decoding, if provided.""" dynamic_config_path: str | None = None + """Path to config file for dynamic speculative decoding, if provided.""" # params generated in the post-init stage draft_model_config: SkipValidation[ModelConfig] = None # type: ignore @@ -565,6 +565,7 @@ def __post_init__(self): # load DynamicSpeculativeConfig: maybe use get_hf_file_to_dict() later if self.dynamic_config_path is not None: import json + with open(self.dynamic_config_path) as f: data = json.load(f) diff --git a/vllm/v1/spec_decode/dynamic/generate_config.py b/vllm/v1/spec_decode/dynamic/generate_config.py index b23ebba05098..2688aed89873 100644 --- a/vllm/v1/spec_decode/dynamic/generate_config.py +++ b/vllm/v1/spec_decode/dynamic/generate_config.py @@ -1,19 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json import pprint import time from pathlib import Path +from vllm.v1.spec_decode.offline import main as spec_decode_main + +from vllm.benchmarks.datasets import add_dataset_parser from vllm.benchmarks.sweep.param_sweep import ParameterSweep from vllm.benchmarks.sweep.serve import SweepServeArgs, run_main from vllm.config.speculative import DynamicSpeculativeConfig -from vllm.v1.spec_decode.offline import main as spec_decode_main from vllm.utils.argparse_utils import FlexibleArgumentParser -from vllm.benchmarks.datasets import add_dataset_parser -def build_serve_params(method, draft_dir, tp, - num_speculative_tokens_list, - prompt_lookup_max, prompt_lookup_min): +def build_serve_params( + method, + draft_dir, + tp, + num_speculative_tokens_list, + prompt_lookup_max, + prompt_lookup_min, +): """Build serve parameter sweep for vanilla + speculative decode configs. 
Each entry becomes a separate server configuration in the sweep. @@ -27,35 +35,41 @@ def build_serve_params(method, draft_dir, tp, # Speculative decoding configs with varying num_speculative_tokens if method == "ngram": for k in num_speculative_tokens_list: - records.append({ - "_benchmark_name": f"ngram-k-{k}", - "speculative_config": { - "method": "ngram", - "num_speculative_tokens": k, - "prompt_lookup_max": prompt_lookup_max, - "prompt_lookup_min": prompt_lookup_min, - }, - }) + records.append( + { + "_benchmark_name": f"ngram-k-{k}", + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": k, + "prompt_lookup_max": prompt_lookup_max, + "prompt_lookup_min": prompt_lookup_min, + }, + } + ) elif method in ("eagle", "eagle3"): for k in num_speculative_tokens_list: - records.append({ - "_benchmark_name": f"{method}-k-{k}", - "speculative_config": { - "method": method, - "model": draft_dir, - "num_speculative_tokens": k, - "draft_tensor_parallel_size": tp, - }, - }) + records.append( + { + "_benchmark_name": f"{method}-k-{k}", + "speculative_config": { + "method": method, + "model": draft_dir, + "num_speculative_tokens": k, + "draft_tensor_parallel_size": tp, + }, + } + ) elif method == "mtp": for k in num_speculative_tokens_list: - records.append({ - "_benchmark_name": f"mtp-k-{k}", - "speculative_config": { - "method": "mtp", - "num_speculative_tokens": k, - }, - }) + records.append( + { + "_benchmark_name": f"mtp-k-{k}", + "speculative_config": { + "method": "mtp", + "num_speculative_tokens": k, + }, + } + ) return ParameterSweep.from_records(records) @@ -68,11 +82,13 @@ def build_bench_params(batch_size_list, num_batches): records = [] for bs in batch_size_list: num_prompts = num_batches * bs - records.append({ - "_benchmark_name": f"bs-{bs}", - "max_concurrency": bs, - "num_prompts": num_prompts, - }) + records.append( + { + "_benchmark_name": f"bs-{bs}", + "max_concurrency": bs, + "num_prompts": num_prompts, + } + ) return ParameterSweep.from_records(records) @@ -109,23 +125,35 @@ def run_profiling_sweep(args): """Run profiling benchmarks using vllm bench sweep serve.""" # Base serve command (static params shared across all serve configs) serve_cmd = [ - "vllm", "serve", args.model_dir, + "vllm", + "serve", + args.model_dir, "--disable-log-requests", - "--gpu-memory-utilization", "0.95", - "--max-num-seqs", str(args.max_vllm_batch_size), - "--tensor-parallel-size", str(args.tp), + "--gpu-memory-utilization", + "0.95", + "--max-num-seqs", + str(args.max_vllm_batch_size), + "--tensor-parallel-size", + str(args.tp), "--enable-chunked-prefill", "--no-enable-prefix-caching", ] # Base bench command (static params shared across all bench configs) bench_cmd = [ - "vllm", "bench", "serve", - "--model", args.model_dir, - "--backend", "openai-chat", - "--endpoint", "/v1/chat/completions", - "--dataset-name", args.dataset_name, - "--dataset-path", args.dataset_path, + "vllm", + "bench", + "serve", + "--model", + args.model_dir, + "--backend", + "openai-chat", + "--endpoint", + "/v1/chat/completions", + "--dataset-name", + args.dataset_name, + "--dataset-path", + args.dataset_path, f"--temperature={args.temp}", f"--top-p={args.top_p}", f"--top-k={args.top_k}", @@ -172,15 +200,17 @@ def main(): parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--draft-dir", type=str, default=None) - parser.add_argument("--method", type=str, default="eagle", - choices=["ngram", "eagle", "eagle3", "mtp"]) parser.add_argument( - "--num-speculative-tokens-list", 
nargs="*", type=int, - default=[1, 3, 5] + "--method", + type=str, + default="eagle", + choices=["ngram", "eagle", "eagle3", "mtp"], ) parser.add_argument( - "--batch-size-list", nargs="*", type=int, - default=[1, 4, 16, 64, 256] + "--num-speculative-tokens-list", nargs="*", type=int, default=[1, 3, 5] + ) + parser.add_argument( + "--batch-size-list", nargs="*", type=int, default=[1, 4, 16, 64, 256] ) parser.add_argument( "--max-vllm-batch-size", @@ -196,8 +226,12 @@ def main(): parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=-1) parser.add_argument("--output-len", type=int, default=256) - parser.add_argument("--num-batches", type=int, default=20, - help="Number of batches to run for each benchmark.") + parser.add_argument( + "--num-batches", + type=int, + default=20, + help="Number of batches to run for each benchmark.", + ) parser.add_argument("--custom-mm-prompts", action="store_true") args = parser.parse_args() @@ -206,9 +240,11 @@ def main(): args.print_output = False args.num_spec_tokens = max(args.num_speculative_tokens_list) args.eagle_dir = args.draft_dir - args.result_dir = (f"{args.result_dir}/tp-{args.tp}_temp-{args.temp}" - f"_top_p-{args.top_p}_top_k-{args.top_k}" - f"/{args.dataset_path}/") + args.result_dir = ( + f"{args.result_dir}/tp-{args.tp}_temp-{args.temp}" + f"_top_p-{args.top_p}_top_k-{args.top_k}" + f"/{args.dataset_path}/" + ) assert args.max_vllm_batch_size == max(args.batch_size_list), ( "max_vllm_batch_size must be equal to max of batch_size_list" @@ -291,4 +327,4 @@ def main(): --result-dir './log/dynamic_sd_test_short' """ if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/vllm/v1/spec_decode/dynamic/manager.py b/vllm/v1/spec_decode/dynamic/manager.py index 37a931ad9fcc..4a55298c0c77 100644 --- a/vllm/v1/spec_decode/dynamic/manager.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -1,9 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import json from vllm.config.speculative import DynamicSpeculativeConfig -from vllm.utils.argparse_utils import FlexibleArgumentParser class DynamicSpeculativeDecodingManager: @@ -15,7 +13,7 @@ class DynamicSpeculativeDecodingManager: def __init__( self, - dynamic_config: DynamicSpeculativeConfig | None, + dynamic_config: DynamicSpeculativeConfig, vllm_max_batch_size: int, vllm_num_speculative_tokens: int, warmup_steps: int = 10, # TODO: make this configurable @@ -36,9 +34,7 @@ def __init__( assert self.dynamic_config.max_num_speculative_tokens == len( self.dynamic_config.acceptance_rate_per_pos - ), ( - "max_num_speculative_tokens must be equal to the length of acceptance_rate_per_pos" - ) + ), "max_num_speculative_tokens != len(acceptance_rate_per_pos)" assert self.dynamic_config.max_num_speculative_tokens > 0, ( "max_num_speculative_tokens must be > 0" ) @@ -46,10 +42,10 @@ def __init__( "all acceptance_rate_per_pos values must be in (0, 1)" ) assert 1 in self.dynamic_config.batch_stats, ( - f"batch size 1 must be available, found: {self.dynamic_config.batch_stats.keys()}" + f"BS 1 not found in {self.dynamic_config.batch_stats.keys()}" ) assert vllm_max_batch_size in self.dynamic_config.batch_stats, ( - f"vllm max_num_seqs {vllm_max_batch_size} must be available, found: {self.dynamic_config.batch_stats.keys()}" + f"max BS not found in {self.dynamic_config.batch_stats.keys()}" ) for bs in self.available_batch_sizes: @@ -153,8 +149,8 @@ def _get_itl(self, batch_stats, 
num_drafts: int) -> float: if num_drafts in batch_stats: return batch_stats[num_drafts] else: - lower_num_draft = max(k for k in batch_stats.keys() if k < num_drafts) - upper_num_draft = min(k for k in batch_stats.keys() if k > num_drafts) + lower_num_draft = max(k for k in batch_stats if k < num_drafts) + upper_num_draft = min(k for k in batch_stats if k > num_drafts) ratio = (num_drafts - lower_num_draft) / (upper_num_draft - lower_num_draft) lower_itl = batch_stats[lower_num_draft] diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 32ab4895670b..7343925620df 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -395,7 +395,7 @@ def propose( # Use optimal num speculative tokens if provided if optimal_num_speculative_tokens is not None: self.num_speculative_tokens = optimal_num_speculative_tokens - + batch_size = common_attn_metadata.batch_size() if self.method == "eagle3": diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 3dfc62359a08..ee6cfa6e11fb 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -47,9 +47,6 @@ def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int): for i in range(num_draft_tokens): self.num_draft_tokens_per_pos[i] += 1 - # REMOVE - # print(f"self.num_drafts: {self.num_drafts}, num_draft_tokens: {num_draft_tokens}, num_accepted_tokens: {num_accepted_tokens}") - class SpecDecodingLogging: """Aggregate and log spec decoding metrics. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2e97a4621fba..9ee1f6e63eba 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -153,8 +153,8 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler -from vllm.v1.spec_decode.dynamic.manager import DynamicSpeculativeDecodingManager from vllm.v1.spec_decode.draft_model import DraftModelProposer +from vllm.v1.spec_decode.dynamic.manager import DynamicSpeculativeDecodingManager from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -500,10 +500,6 @@ def __init__( else: self.dynamic_sd_manager = None - # REMOVE - if self.dynamic_sd_manager: - print(f"_optimal_num_speculative_tokens: {self.dynamic_sd_manager._optimal_num_speculative_tokens}") - # Request states. self.requests: dict[str, CachedRequestState] = {} # NOTE(rob): num_prompt_logprobs only includes reqs From 5df4999b79549c7fcf721a6b2fabe33567c7d407 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 17 Feb 2026 04:48:12 +0000 Subject: [PATCH 20/39] fix mypy Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/dynamic/generate_config.py | 4 ++-- vllm/v1/spec_decode/dynamic/manager.py | 18 +++++++++++++----- vllm/v1/worker/gpu_model_runner.py | 7 ++++--- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/vllm/v1/spec_decode/dynamic/generate_config.py b/vllm/v1/spec_decode/dynamic/generate_config.py index 2688aed89873..c3187980d548 100644 --- a/vllm/v1/spec_decode/dynamic/generate_config.py +++ b/vllm/v1/spec_decode/dynamic/generate_config.py @@ -27,7 +27,7 @@ def build_serve_params( Each entry becomes a separate server configuration in the sweep. 
The sweep framework starts/stops the server for each serve config. """ - records = [] + records: list[dict[str, object]] = [] # Vanilla config (no speculative decoding) records.append({"_benchmark_name": "vanilla"}) @@ -99,7 +99,7 @@ def parse_itl_from_dataframe(result_df): batch_stats: dict of {batch_size: {num_drafts: median_itl_ms}} where num_drafts=0 corresponds to vanilla (no speculation). """ - batch_stats = {} + batch_stats: dict[int, dict[int, float]] = {} for _, row in result_df.iterrows(): bs = int(row["max_concurrency"]) diff --git a/vllm/v1/spec_decode/dynamic/manager.py b/vllm/v1/spec_decode/dynamic/manager.py index 4a55298c0c77..3ec93d847534 100644 --- a/vllm/v1/spec_decode/dynamic/manager.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -118,13 +118,21 @@ def _get_batch_stats(self, batch_size: int) -> dict: # find the nearest batch size smaller and bigger than the given batch size # and return the weighted avg of their stats - smaller_bs = [bs for bs in self.available_batch_sizes if bs < batch_size] + smaller_bs_list = [ + bs for bs in self.available_batch_sizes if bs < batch_size + ] smaller_bs = ( - max(smaller_bs) if len(smaller_bs) else self.available_batch_sizes[0] + max(smaller_bs_list) + if len(smaller_bs_list) + else self.available_batch_sizes[0] ) - larger_bs = [bs for bs in self.available_batch_sizes if bs > batch_size] + larger_bs_list = [ + bs for bs in self.available_batch_sizes if bs > batch_size + ] larger_bs = ( - min(larger_bs) if len(larger_bs) else self.available_batch_sizes[-1] + min(larger_bs_list) + if len(larger_bs_list) + else self.available_batch_sizes[-1] ) smaller_bs_stat = self.batch_stats[smaller_bs] @@ -160,7 +168,7 @@ def _get_itl(self, batch_stats, num_drafts: int) -> float: def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: batch_stats = self._get_batch_stats(batch_size) - max_goodput = -1 + max_goodput = -1.0 for num_drafts in range(self.dynamic_config.max_num_speculative_tokens + 1): curr_al = 1 + sum(self.dynamic_config.acceptance_rate_per_pos[:num_drafts]) curr_itl = self._get_itl(batch_stats, num_drafts) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9ee1f6e63eba..842c85a68594 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -491,14 +491,15 @@ def __init__( self.effective_drafter_max_model_len = self.max_model_len # setup Dynamic Speculative Decoding + self.dynamic_sd_manager: ( + DynamicSpeculativeDecodingManager | None + ) = None if self.speculative_config.dynamic_config: self.dynamic_sd_manager = DynamicSpeculativeDecodingManager( self.speculative_config.dynamic_config, self.vllm_config.scheduler_config.max_num_seqs, - self.vllm_config.speculative_config.num_speculative_tokens, + self.speculative_config.num_speculative_tokens, ) - else: - self.dynamic_sd_manager = None # Request states. 
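# A simplified standalone sketch of the interpolation idea behind
# _get_batch_stats/_get_itl when the live batch size was not profiled directly:
# take the nearest smaller and larger profiled batch sizes and linearly blend
# their ITLs. For brevity this helper assumes every profiled batch size was
# measured at the same K values; the numbers in the example are invented.
def interpolate_itl(batch_stats: dict[int, dict[int, float]],
                    batch_size: int, k: int) -> float:
    sizes = sorted(batch_stats)
    lo = max([b for b in sizes if b <= batch_size], default=sizes[0])
    hi = min([b for b in sizes if b >= batch_size], default=sizes[-1])
    ratio = 0.0 if hi == lo else (batch_size - lo) / (hi - lo)
    return batch_stats[lo][k] + ratio * (batch_stats[hi][k] - batch_stats[lo][k])

# interpolate_itl({4: {3: 9.95}, 16: {3: 10.80}}, batch_size=8, k=3) ~= 10.23 ms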
self.requests: dict[str, CachedRequestState] = {} From e76ad8ed77cbde512739b0e2124d5d2087c9d944 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 17 Feb 2026 04:50:38 +0000 Subject: [PATCH 21/39] fix mypy Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/ngram_proposer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index e7dfec0c8d5f..d20a10ee4d3b 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -55,6 +55,7 @@ def __init__(self, vllm_config: VllmConfig): # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose( + None, [[]] * 1024, np.zeros(1024, dtype=np.int32), np.zeros((1024, self.max_model_len), dtype=np.int32), @@ -140,7 +141,7 @@ def propose( ) -> list[list[int]]: # Use optimal num speculative tokens if provided if optimal_num_speculative_tokens is not None: - self.num_speculative_tokens = optimal_num_speculative_tokens + self.k = optimal_num_speculative_tokens # find which requests need ngram proposals valid_ngram_requests = [] From c1e880ba4bd3f532e77fd7597d488283ae1c857e Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 17 Feb 2026 05:19:32 +0000 Subject: [PATCH 22/39] lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/dynamic/manager.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/v1/spec_decode/dynamic/manager.py b/vllm/v1/spec_decode/dynamic/manager.py index 3ec93d847534..e248c99c94ea 100644 --- a/vllm/v1/spec_decode/dynamic/manager.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -38,8 +38,8 @@ def __init__( assert self.dynamic_config.max_num_speculative_tokens > 0, ( "max_num_speculative_tokens must be > 0" ) - assert all(0 < a < 1 for a in self.dynamic_config.acceptance_rate_per_pos), ( - "all acceptance_rate_per_pos values must be in (0, 1)" + assert all(0.0 <= a <= 1.0 for a in self.dynamic_config.acceptance_rate_per_pos), ( + "all acceptance_rate_per_pos values must be in [0.0, 1.0]" ) assert 1 in self.dynamic_config.batch_stats, ( f"BS 1 not found in {self.dynamic_config.batch_stats.keys()}" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 842c85a68594..70b8c3290b6d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -491,9 +491,7 @@ def __init__( self.effective_drafter_max_model_len = self.max_model_len # setup Dynamic Speculative Decoding - self.dynamic_sd_manager: ( - DynamicSpeculativeDecodingManager | None - ) = None + self.dynamic_sd_manager: DynamicSpeculativeDecodingManager | None = None if self.speculative_config.dynamic_config: self.dynamic_sd_manager = DynamicSpeculativeDecodingManager( self.speculative_config.dynamic_config, From 72d3c6f5c561550c5654e11a55d14cb4d2a8ad52 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 17 Feb 2026 05:25:00 +0000 Subject: [PATCH 23/39] lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/dynamic/manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/spec_decode/dynamic/manager.py 
b/vllm/v1/spec_decode/dynamic/manager.py index e248c99c94ea..3278293fab2c 100644 --- a/vllm/v1/spec_decode/dynamic/manager.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -38,9 +38,9 @@ def __init__( assert self.dynamic_config.max_num_speculative_tokens > 0, ( "max_num_speculative_tokens must be > 0" ) - assert all(0.0 <= a <= 1.0 for a in self.dynamic_config.acceptance_rate_per_pos), ( - "all acceptance_rate_per_pos values must be in [0.0, 1.0]" - ) + assert all( + 0.0 <= a <= 1.0 for a in self.dynamic_config.acceptance_rate_per_pos + ), "all acceptance_rate_per_pos values must be in [0.0, 1.0]" assert 1 in self.dynamic_config.batch_stats, ( f"BS 1 not found in {self.dynamic_config.batch_stats.keys()}" ) From 534d2f2961ef93e00caf5d42dff3f5b1d4313bc1 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 13 Mar 2026 03:44:17 +0000 Subject: [PATCH 24/39] add AL computation to generate_config Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- .../v1/spec_decode/dynamic/generate_config.py | 138 +++++++++++++++++- vllm/v1/worker/gpu_model_runner.py | 7 + 2 files changed, 140 insertions(+), 5 deletions(-) diff --git a/vllm/v1/spec_decode/dynamic/generate_config.py b/vllm/v1/spec_decode/dynamic/generate_config.py index c3187980d548..d3e38691f361 100644 --- a/vllm/v1/spec_decode/dynamic/generate_config.py +++ b/vllm/v1/spec_decode/dynamic/generate_config.py @@ -5,14 +5,16 @@ import time from pathlib import Path -from vllm.v1.spec_decode.offline import main as spec_decode_main - from vllm.benchmarks.datasets import add_dataset_parser +from vllm.benchmarks.datasets import get_samples from vllm.benchmarks.sweep.param_sweep import ParameterSweep from vllm.benchmarks.sweep.serve import SweepServeArgs, run_main from vllm.config.speculative import DynamicSpeculativeConfig +from vllm import LLM +from transformers import AutoTokenizer +from vllm.sampling_params import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser - +from vllm.v1.metrics.reader import Counter, Vector def build_serve_params( method, @@ -128,7 +130,6 @@ def run_profiling_sweep(args): "vllm", "serve", args.model_dir, - "--disable-log-requests", "--gpu-memory-utilization", "0.95", "--max-num-seqs", @@ -188,12 +189,132 @@ def run_profiling_sweep(args): resume=None, link_vars=[], server_ready_timeout=600, + experiment_name=f"{args.method}-{args.num_spec_tokens}", ) result_df = run_main(sweep_args) return result_df +def get_acceptance_rate_per_pos(args): + """Get acceptance rate per position.""" + + tokenizer = AutoTokenizer.from_pretrained(args.model_dir) + prompts = get_samples(args, tokenizer) + if args.enable_multimodal_chat: + llm_prompts = [p.prompt for p in prompts] + else: + # add_special_tokens is False to avoid adding bos twice + # when using chat templates + llm_prompts = [ + { + "prompt_token_ids": tokenizer.encode( + prompt.prompt, add_special_tokens=False + ), + "multi_modal_data": prompt.multi_modal_data, + } + for prompt in prompts + ] + if args.method == "eagle" or args.method == "eagle3": + eagle_dir = args.eagle_dir + if args.method == "eagle" and eagle_dir is None: + eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + + elif args.method == "eagle3" and eagle_dir is None: + eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + speculative_config = { + "method": args.method, + "model": eagle_dir, + "num_speculative_tokens": args.num_spec_tokens, + "disable_padded_drafter_batch": args.disable_padded_drafter_batch, + 
"parallel_drafting": args.parallel_drafting, + } + elif args.method == "ngram": + speculative_config = { + "method": "ngram", + "num_speculative_tokens": args.num_spec_tokens, + "prompt_lookup_max": args.prompt_lookup_max, + "prompt_lookup_min": args.prompt_lookup_min, + } + elif args.method == "draft_model": + assert args.draft_model is not None and args.draft_model != "" + speculative_config = { + "method": args.method, + "model": args.draft_model, + "num_speculative_tokens": args.num_spec_tokens, + "enforce_eager": args.enforce_eager, + "max_model_len": args.max_model_len, + "parallel_drafting": args.parallel_drafting, + } + elif args.method == "mtp": + speculative_config = { + "method": "mtp", + "num_speculative_tokens": args.num_spec_tokens, + } + else: + raise ValueError(f"unknown method: {args.method}") + + llm = LLM( + model=args.model_dir, + trust_remote_code=True, + tensor_parallel_size=args.tp, + enable_chunked_prefill=args.enable_chunked_prefill, + enforce_eager=args.enforce_eager, + gpu_memory_utilization=args.gpu_memory_utilization, + speculative_config=speculative_config, + disable_log_stats=False, + max_model_len=args.max_model_len, + limit_mm_per_prompt={"image": 5}, + disable_chunked_mm_input=True, + max_num_seqs=args.max_vllm_batch_size, + allowed_local_media_path=args.allowed_local_media_path, + ) + + sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) + if args.backend == "openai-chat": + outputs = llm.chat(llm_prompts, sampling_params=sampling_params) + else: + outputs = llm.generate( + llm_prompts, + sampling_params=sampling_params, + ) + + metrics = llm.get_metrics() + + num_drafts = 0 + num_draft_tokens = 0 + num_accepted_tokens = 0 + acceptance_counts = [0] * args.num_spec_tokens + for metric in metrics: + if metric.name == "vllm:spec_decode_num_drafts": + assert isinstance(metric, Counter) + num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_draft_tokens": + assert isinstance(metric, Counter) + num_draft_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens": + assert isinstance(metric, Counter) + num_accepted_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + assert isinstance(metric, Vector) + for pos in range(len(metric.values)): + acceptance_counts[pos] += metric.values[pos] + + + acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 + print(f"mean acceptance length: {acceptance_length:.2f}") + print("-" * 50) + + # print acceptance at each token position + acceptance_rate_per_pos = [] + for i in range(len(acceptance_counts)): + acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 + print(f"acceptance at token {i}: {acceptance_rate:.2f}") + acceptance_rate_per_pos.append(acceptance_rate) + + return acceptance_length, acceptance_rate_per_pos + + def main(): parser = FlexibleArgumentParser() add_dataset_parser(parser) @@ -217,7 +338,12 @@ def main(): type=int, help="Max vllm server batch size (max concurrency)", ) + parser.add_argument("--backend", type=str, default="openai") parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--enforce-eager", action="store_true") + parser.add_argument("--enable-chunked-prefill", action="store_true") + parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) + parser.add_argument("--disable-padded-drafter-batch", action="store_true") parser.add_argument("--result-dir", type=str, default="./log/dynamic_sd") 
parser.add_argument("--prompt-lookup-max", type=int, default=5) parser.add_argument("--prompt-lookup-min", type=int, default=2) @@ -226,6 +352,8 @@ def main(): parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=-1) parser.add_argument("--output-len", type=int, default=256) + parser.add_argument("--parallel-drafting", action="store_true") + parser.add_argument("--allowed-local-media-path", type=str, default="") parser.add_argument( "--num-batches", type=int, @@ -254,7 +382,7 @@ def main(): start = time.time() # Step 1: get acceptance_rate_per_pos - acceptance_length, acceptance_rate_per_pos = spec_decode_main(args) + acceptance_length, acceptance_rate_per_pos = get_acceptance_rate_per_pos(args) print(f"Acceptance length: {acceptance_length}") print(f"Acceptance rate per position: {acceptance_rate_per_pos}") print("✅ Step 1: obtained acceptance rate per position.") diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d3c94cd1cbad..cb3c5c75b324 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4243,6 +4243,11 @@ def propose_draft_token_ids( common_attn_metadata: CommonAttentionMetadata, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, ) -> list[list[int]] | torch.Tensor: + + # REMOVE + print(f"scheduler_output.spec_decoding_stats_all: {scheduler_output.spec_decoding_stats_all}") + print(f"self.input_batch.num_reqs: {self.input_batch.num_reqs}") + optimal_num_speculative_tokens = None if self.dynamic_sd_manager: optimal_num_speculative_tokens = self.dynamic_sd_manager.step( @@ -4250,6 +4255,8 @@ def propose_draft_token_ids( self.input_batch.num_reqs, ) + print(f"optimal_num_speculative_tokens: {optimal_num_speculative_tokens}") + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None From 3f7196ea51dfd714aa0a929aaa9f9360c3fde87a Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 13 Mar 2026 04:28:07 +0000 Subject: [PATCH 25/39] fix padding for async sched Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- .../v1/spec_decode/dynamic/generate_config.py | 14 ++++++------- vllm/v1/worker/gpu_model_runner.py | 21 ++++++++++++------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/vllm/v1/spec_decode/dynamic/generate_config.py b/vllm/v1/spec_decode/dynamic/generate_config.py index d3e38691f361..35692f2cf73a 100644 --- a/vllm/v1/spec_decode/dynamic/generate_config.py +++ b/vllm/v1/spec_decode/dynamic/generate_config.py @@ -5,17 +5,18 @@ import time from pathlib import Path -from vllm.benchmarks.datasets import add_dataset_parser -from vllm.benchmarks.datasets import get_samples +from transformers import AutoTokenizer + +from vllm import LLM +from vllm.benchmarks.datasets import add_dataset_parser, get_samples from vllm.benchmarks.sweep.param_sweep import ParameterSweep from vllm.benchmarks.sweep.serve import SweepServeArgs, run_main from vllm.config.speculative import DynamicSpeculativeConfig -from vllm import LLM -from transformers import AutoTokenizer from vllm.sampling_params import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.v1.metrics.reader import Counter, Vector + def build_serve_params( method, draft_dir, @@ -272,9 +273,9 @@ def get_acceptance_rate_per_pos(args): sampling_params = 
SamplingParams(temperature=args.temp, max_tokens=args.output_len) if args.backend == "openai-chat": - outputs = llm.chat(llm_prompts, sampling_params=sampling_params) + _ = llm.chat(llm_prompts, sampling_params=sampling_params) else: - outputs = llm.generate( + _ = llm.generate( llm_prompts, sampling_params=sampling_params, ) @@ -300,7 +301,6 @@ def get_acceptance_rate_per_pos(args): for pos in range(len(metric.values)): acceptance_counts[pos] += metric.values[pos] - acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 print(f"mean acceptance length: {acceptance_length:.2f}") print("-" * 50) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cb3c5c75b324..a793b81fa601 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4243,11 +4243,6 @@ def propose_draft_token_ids( common_attn_metadata: CommonAttentionMetadata, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, ) -> list[list[int]] | torch.Tensor: - - # REMOVE - print(f"scheduler_output.spec_decoding_stats_all: {scheduler_output.spec_decoding_stats_all}") - print(f"self.input_batch.num_reqs: {self.input_batch.num_reqs}") - optimal_num_speculative_tokens = None if self.dynamic_sd_manager: optimal_num_speculative_tokens = self.dynamic_sd_manager.step( @@ -4255,8 +4250,6 @@ def propose_draft_token_ids( self.input_batch.num_reqs, ) - print(f"optimal_num_speculative_tokens: {optimal_num_speculative_tokens}") - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None @@ -4505,6 +4498,20 @@ def propose_draft_token_ids( slot_mappings=slot_mappings, ) + if ( + optimal_num_speculative_tokens is not None + and isinstance(draft_token_ids, torch.Tensor) + and draft_token_ids.dim() == 2 + and draft_token_ids.shape[1] < self.num_spec_tokens + ): + padding = torch.zeros( + draft_token_ids.shape[0], + self.num_spec_tokens - draft_token_ids.shape[1], + device=draft_token_ids.device, + dtype=draft_token_ids.dtype, + ) + draft_token_ids = torch.cat([draft_token_ids, padding], dim=1) + return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: From 64fab8dc13525b89993904bc473a3125f4bc09b3 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 13 Mar 2026 00:29:27 -0400 Subject: [PATCH 26/39] Update vllm/config/speculative.py Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/config/speculative.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index b86b31502c90..87b33004332c 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -60,7 +60,8 @@ RejectionSampleMethod = Literal["strict", "probabilistic"] -class DynamicSpeculativeConfig(BaseModel): +@config +class DynamicSpeculativeConfig: # """A mapping from batch size to optimal number of drafts to use for that # batch size. 
This is used to dynamically adjust the number of drafts used # based on the current batch size.""" From 1198aa742ff28694b3909c9231be7e5ed853e2d8 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 13 Mar 2026 04:38:20 +0000 Subject: [PATCH 27/39] fix docstring Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/config/speculative.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 87b33004332c..ece489e7e22d 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -5,7 +5,7 @@ import copy from typing import TYPE_CHECKING, Any, Literal, get_args -from pydantic import BaseModel, Field, SkipValidation, model_validator +from pydantic import Field, SkipValidation, model_validator from typing_extensions import Self from vllm.config import LoadConfig @@ -62,15 +62,14 @@ @config class DynamicSpeculativeConfig: - # """A mapping from batch size to optimal number of drafts to use for that - # batch size. This is used to dynamically adjust the number of drafts used - # based on the current batch size.""" - # optimal_num_speculative_tokens: dict[int, int] = None - - """Whether the statistics are updated online or not during inference.""" + """A mapping from batch size to optimal number of drafts to use for that + batch size. This is used to dynamically adjust the number of drafts used + based on the current batch size.""" is_online: bool = False + """Whether the statistics are updated online or not during inference.""" + batch_stats: dict[int, dict[int, float]] = None """ Batch statistics for different batch sizes and number of drafts. The structure is as follows: @@ -88,13 +87,12 @@ class DynamicSpeculativeConfig: where bs 1 at K=3 has itl 9.41ms. K=0 means no speculative decoding. 
""" - batch_stats: dict[int, dict[int, float]] = None - """Maximum number of speculative tokens supported in the statistics.""" max_num_speculative_tokens: int = None + """Maximum number of speculative tokens supported in the statistics.""" - """Acceptance rate per position on an offline dataset.""" acceptance_rate_per_pos: list[float] = None + """Acceptance rate per position on an offline dataset.""" @config From c07afd14262115068550d3395f93d85d7c643b5f Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:07:50 +0000 Subject: [PATCH 28/39] lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/config/speculative.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index ece489e7e22d..4b7ed0b07c29 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -648,7 +648,7 @@ def __post_init__(self): with open(self.dynamic_config_path) as f: data = json.load(f) - self.dynamic_config = DynamicSpeculativeConfig.model_validate(data) + self.dynamic_config = DynamicSpeculativeConfig(**data) else: self.dynamic_config = None From 594dc0f84996b72c38a3d58ff36d63cd751d1de9 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Sat, 14 Mar 2026 02:40:11 +0000 Subject: [PATCH 29/39] make DSD compat with async and padded drafter Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/core/sched/async_scheduler.py | 30 ++++++++++++++++++++++++++ vllm/v1/outputs.py | 4 ++++ vllm/v1/spec_decode/dynamic/manager.py | 1 - vllm/v1/spec_decode/eagle.py | 11 ++++++++++ vllm/v1/worker/gpu_model_runner.py | 9 ++++++++ 5 files changed, 54 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index 0b3958dbcf5a..a69035a78234 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -1,11 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +from typing import TYPE_CHECKING + from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.request import Request, RequestStatus +if TYPE_CHECKING: + from vllm.v1.engine import EngineCoreOutputs + from vllm.v1.outputs import ModelRunnerOutput + logger = init_logger(__name__) @@ -14,6 +22,28 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # reusable read-only placeholder list for speculative decoding. self._spec_token_placeholders: list[int] = [-1] * self.num_spec_tokens + # Dynamic SD: effective K for the current step, updated from + # ModelRunnerOutput after each model execution. 
+ self._dynamic_num_spec_tokens: int | None = None + + def _update_placeholders_from_dynamic_sd(self, optimal_k: int | None) -> None: + """Update placeholder count based on Dynamic SD decision.""" + if optimal_k is None: + self._dynamic_num_spec_tokens = None + self._spec_token_placeholders = [-1] * self.num_spec_tokens + elif optimal_k != self._dynamic_num_spec_tokens: + self._dynamic_num_spec_tokens = optimal_k + self._spec_token_placeholders = [-1] * optimal_k + + def update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> dict[int, EngineCoreOutputs]: + self._update_placeholders_from_dynamic_sd( + model_runner_output.optimal_num_speculative_tokens + ) + return super().update_from_output(scheduler_output, model_runner_output) def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: super()._update_after_schedule(scheduler_output) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 8eb58de4f3fd..221c705c2142 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -253,6 +253,10 @@ class ModelRunnerOutput: # information related to cudagraph execution cudagraph_stats: CUDAGraphStat | None = None + # Dynamic Speculative Decoding: optimal K chosen for this step. + # None means DSD is not active (use the static num_spec_tokens). + optimal_num_speculative_tokens: int | None = None + # ModelRunnerOutput wrapper for async scheduling. class AsyncModelRunnerOutput(ABC): diff --git a/vllm/v1/spec_decode/dynamic/manager.py b/vllm/v1/spec_decode/dynamic/manager.py index 3278293fab2c..1a983c57a631 100644 --- a/vllm/v1/spec_decode/dynamic/manager.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -113,7 +113,6 @@ def update_acceptance_rate_per_pos(self, acceptance_rate_per_pos: list[float]): self.update_optimal_num_speculative_tokens() def _get_batch_stats(self, batch_size: int) -> dict: - # import pdb; pdb.set_trace() if batch_size not in self.batch_stats: # find the nearest batch size smaller and bigger than the given batch size # and return the weighted avg of their stats diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 11fddfe217eb..0f174c6ca737 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -485,6 +485,17 @@ def propose( sample_hidden_states = last_hidden_states[token_indices_to_sample] + # No draft tokens requested (e.g. Dynamic SD decided K=0). + # The prefill forward pass above already ran to keep the drafter + # KV cache in sync, so just return an empty tensor. + if self.num_speculative_tokens == 0: + return torch.empty( + batch_size, + 0, + device=sample_hidden_states.device, + dtype=torch.int64, + ) + # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1 or self.parallel_drafting: draft_token_ids = self._greedy_sample(sample_hidden_states) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a793b81fa601..26862beeb10d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -767,6 +767,7 @@ def __init__( # Cached outputs. self._draft_token_ids: list[list[int]] | torch.Tensor | None = None + self._optimal_num_speculative_tokens: int | None = None # N-gram GPU path: async D2H buffer/event for per-request valid draft counts. 
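# A standalone illustration of the pad/trim pair used when the dynamic manager
# picks K below the static num_spec_tokens: the draft tensor is zero-padded
# back to the static width so fixed-shape GPU indexing keeps working, and only
# the first K columns are later handed to the scheduler. Helper names here are
# invented for illustration.
import torch

def pad_drafts(draft_token_ids: torch.Tensor, static_k: int) -> torch.Tensor:
    bs, k = draft_token_ids.shape
    if k >= static_k:
        return draft_token_ids
    pad = torch.zeros(bs, static_k - k,
                      dtype=draft_token_ids.dtype,
                      device=draft_token_ids.device)
    return torch.cat([draft_token_ids, pad], dim=1)

def trim_drafts(drafts_per_req: list[list[int]], k: int) -> list[list[int]]:
    # Drop the padded tail before the scheduler sees the drafts.
    return [ids[:k] for ids in drafts_per_req]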
self._num_valid_draft_tokens: torch.Tensor | None = None self._num_valid_draft_tokens_cpu: torch.Tensor | None = None @@ -4083,6 +4084,7 @@ def propose_draft_token_ids(sampled_token_ids): else None, num_nans_in_logits=num_nans_in_logits, cudagraph_stats=cudagraph_stats, + optimal_num_speculative_tokens=(self._optimal_num_speculative_tokens), ) if not self.use_async_scheduling: @@ -4154,6 +4156,12 @@ def take_draft_token_ids(self) -> DraftTokenIds | None: if not self.num_spec_tokens or not self._draft_token_req_ids: return None draft_token_ids, req_ids = self._get_draft_token_ids_cpu() + # When Dynamic SD reduced the number of speculative tokens, + # the GPU tensor was zero-padded to num_spec_tokens for scatter + # indexing, but the scheduler should only see the real draft tokens. + k = self._optimal_num_speculative_tokens + if k is not None and k < self.num_spec_tokens: + draft_token_ids = [ids[:k] for ids in draft_token_ids] return DraftTokenIds(req_ids, draft_token_ids) def _copy_draft_token_ids_to_cpu( @@ -4249,6 +4257,7 @@ def propose_draft_token_ids( scheduler_output.spec_decoding_stats_all, self.input_batch.num_reqs, ) + self._optimal_num_speculative_tokens = optimal_num_speculative_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config From cd1975019a39452e6b1e9fe6727a8577d4075c51 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Sat, 14 Mar 2026 02:49:12 +0000 Subject: [PATCH 30/39] lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/dynamic/generate_config.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/dynamic/generate_config.py b/vllm/v1/spec_decode/dynamic/generate_config.py index 35692f2cf73a..c20aed5ed5cf 100644 --- a/vllm/v1/spec_decode/dynamic/generate_config.py +++ b/vllm/v1/spec_decode/dynamic/generate_config.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses import json import pprint import time from pathlib import Path +from typing import Any from transformers import AutoTokenizer @@ -187,7 +189,7 @@ def run_profiling_sweep(args): output_dir=Path(args.result_dir), num_runs=1, dry_run=False, - resume=None, + resume=False, link_vars=[], server_ready_timeout=600, experiment_name=f"{args.method}-{args.num_spec_tokens}", @@ -202,6 +204,7 @@ def get_acceptance_rate_per_pos(args): tokenizer = AutoTokenizer.from_pretrained(args.model_dir) prompts = get_samples(args, tokenizer) + llm_prompts: list[Any] if args.enable_multimodal_chat: llm_prompts = [p.prompt for p in prompts] else: @@ -410,7 +413,7 @@ def main(): config_path = f"{args.result_dir}/dynamic_speculative_config.json" with open(config_path, "w") as f: - json.dump(dynamic_config.model_dump(), f, indent=4) + json.dump(dataclasses.asdict(dynamic_config), f, indent=4) print(f"✅ Step 4: config saved to {config_path}") From 4307feba5471537b86777410b9269fd9b7cc238a Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 17 Mar 2026 04:18:30 +0000 Subject: [PATCH 31/39] optimize DSD async scheduling by minimizing delay in propagating optimal K Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/core/sched/async_scheduler.py | 67 +++++++++++++++++++++++++-- vllm/v1/core/sched/scheduler.py | 2 - vllm/v1/worker/gpu_model_runner.py | 37 
+++++++++++++++ 3 files changed, 101 insertions(+), 5 deletions(-) diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index a69035a78234..bb4f3f9a11ff 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -26,6 +26,15 @@ def __init__(self, *args, **kwargs) -> None: # ModelRunnerOutput after each model execution. self._dynamic_num_spec_tokens: int | None = None + # Dynamic SD async scheduling state: + # optimal K learned from the latest model output, to be applied + # at the beginning of the next schedule() call. + self._pending_optimal_k: int | None = None + # req_id -> committed spec token count from the most recent + # _update_after_schedule. Used to correct accounting when + # the optimal K changes between scheduling and output processing. + self._in_flight_decode_req_k: dict[str, int] = {} + def _update_placeholders_from_dynamic_sd(self, optimal_k: int | None) -> None: """Update placeholder count based on Dynamic SD decision.""" if optimal_k is None: @@ -35,19 +44,68 @@ def _update_placeholders_from_dynamic_sd(self, optimal_k: int | None) -> None: self._dynamic_num_spec_tokens = optimal_k self._spec_token_placeholders = [-1] * optimal_k + def _apply_pending_dynamic_sd_update(self) -> None: + """Apply a deferred dynamic SD K change at the start of schedule(). + + When update_from_output() learns a new optimal K, it cannot + immediately correct the in-flight step's accounting because we + need to target the right set of requests (those committed in the + most recent _update_after_schedule, tracked in + _in_flight_decode_req_k). This method is called at the beginning + of schedule() so that: + 1. The in-flight step's over/under-committed accounting is fixed. + 2. request.spec_token_ids is updated before the scheduling loop + reads it. + 3. _spec_token_placeholders reflects the new K for the current + scheduling step. + """ + optimal_k = self._pending_optimal_k + if optimal_k is None: + return + + self._pending_optimal_k = None + self._update_placeholders_from_dynamic_sd(optimal_k) + + for req_id, committed_k in self._in_flight_decode_req_k.items(): + diff = committed_k - optimal_k + if diff <= 0: + # K stayed the same or increased; the in-flight step + # under-allocated (if anything) but we cannot retroactively + # add more spec tokens to a step already on the GPU. + # Just update spec_token_ids for the next scheduling step. 
+ request = self.requests.get(req_id) + if request is not None and not request.is_finished(): + request.spec_token_ids = self._spec_token_placeholders + continue + + request = self.requests.get(req_id) + if request is None or request.is_finished(): + continue + + request.num_output_placeholders -= diff + request.num_computed_tokens -= diff + request.spec_token_ids = self._spec_token_placeholders + + self._in_flight_decode_req_k = {} + + def schedule(self) -> SchedulerOutput: + self._apply_pending_dynamic_sd_update() + return super().schedule() + def update_from_output( self, scheduler_output: SchedulerOutput, model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: - self._update_placeholders_from_dynamic_sd( - model_runner_output.optimal_num_speculative_tokens - ) + optimal_k = model_runner_output.optimal_num_speculative_tokens + if optimal_k is not None: + self._pending_optimal_k = optimal_k return super().update_from_output(scheduler_output, model_runner_output) def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: super()._update_after_schedule(scheduler_output) spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens + self._in_flight_decode_req_k = {} for req_id in scheduler_output.num_scheduled_tokens: request = self.requests[req_id] if request.is_prefill_chunk: @@ -64,6 +122,9 @@ def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: # We will update the actual spec token ids in the worker process. request.spec_token_ids = self._spec_token_placeholders + if cur_num_spec_tokens > 0: + self._in_flight_decode_req_k[req_id] = cur_num_spec_tokens + def _update_request_with_output( self, request: Request, new_token_ids: list[int] ) -> tuple[list[int], bool]: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 92ae397d8144..97194c692077 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -908,8 +908,6 @@ def schedule(self) -> SchedulerOutput: new_block_ids_to_zero=new_block_ids_to_zero, ) - # REMOVE - # NOTE(Kuntai): this function is designed for multiple purposes: # 1. Plan the KV cache store # 2. Wrap up all the KV cache load / save ops into an opaque object diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c4d10fb56507..ecfbfe2b6319 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1163,6 +1163,16 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_data = scheduler_output.scheduled_cached_reqs scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens + # When dynamic SD chose a smaller K in the previous step, trim the + # scheduler output in-place so that input preparation, rejection + # accounting, and prev_num_draft_len all reflect the actual K. + if ( + self._optimal_num_speculative_tokens is not None + and self.use_async_scheduling + and scheduled_spec_tokens + ): + self._trim_spec_tokens_for_dynamic_sd(scheduler_output) + # Save scheduler-allocated spec lengths before trimming so # prev_num_draft_len keeps the optimistic count for rejection correction. 
original_num_spec_per_req: dict[str, int] = {} @@ -4152,6 +4162,33 @@ def _pp_receive_prev_sampled_token_ids_to_input_batch(self) -> None: req_state.output_token_ids.append(-1) self.input_batch.prev_req_id_to_index = prev_req_id_to_index + def _trim_spec_tokens_for_dynamic_sd( + self, scheduler_output: "SchedulerOutput" + ) -> None: + """Trim scheduled spec tokens to match dynamic SD's optimal K. + + Called in _update_states when async scheduling is active and the + previous step determined a lower optimal K than what the scheduler + allocated. Modifies scheduler_output in-place so that the engine + core's update_from_output sees the trimmed counts in its rejection + logic, avoiding double-correction with the scheduler-side fix. + """ + k = self._optimal_num_speculative_tokens + assert k is not None + spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens + for req_id in list(spec_decode_tokens): + tokens = spec_decode_tokens[req_id] + scheduled_k = len(tokens) + if scheduled_k <= k: + continue + tokens_to_trim = scheduled_k - k + scheduler_output.total_num_scheduled_tokens -= tokens_to_trim + scheduler_output.num_scheduled_tokens[req_id] -= tokens_to_trim + if k == 0: + spec_decode_tokens.pop(req_id) + else: + spec_decode_tokens[req_id] = tokens[:k] + def take_draft_token_ids(self) -> DraftTokenIds | None: if not self.num_spec_tokens or not self._draft_token_req_ids: return None From 018b4bd704b08f19fdfff055c82b3c361d5f8ab5 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 17 Mar 2026 04:23:58 +0000 Subject: [PATCH 32/39] dsd config path field Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/config/speculative.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 4b7ed0b07c29..f5046601166c 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -179,6 +179,8 @@ class SpeculativeConfig: # dynamic speculative decoding control dynamic_config_path: str | None = None """Path to config file for dynamic speculative decoding, if provided.""" + dynamic_config: SkipValidation[DynamicSpeculativeConfig] | None = None + """Loaded dynamic speculative config, populated from dynamic_config_path.""" # params generated in the post-init stage draft_model_config: SkipValidation[ModelConfig] = None # type: ignore @@ -649,8 +651,6 @@ def __post_init__(self): data = json.load(f) self.dynamic_config = DynamicSpeculativeConfig(**data) - else: - self.dynamic_config = None return self From 36f5a36a46d55faaa4c28050babdc76ef7374ea0 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 17 Mar 2026 04:54:37 +0000 Subject: [PATCH 33/39] refactor to simplify propose signature and update test Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- tests/v1/spec_decode/test_eagle.py | 2 ++ .../spec_decode/test_extract_hidden_states.py | 2 ++ tests/v1/spec_decode/test_mtp.py | 1 + tests/v1/spec_decode/test_ngram.py | 10 +++++++ vllm/v1/spec_decode/eagle.py | 6 ++-- vllm/v1/spec_decode/extract_hidden_states.py | 7 ++++- vllm/v1/spec_decode/medusa.py | 4 ++- vllm/v1/spec_decode/ngram_proposer.py | 14 ++++++---- vllm/v1/spec_decode/ngram_proposer_gpu.py | 3 ++ vllm/v1/spec_decode/suffix_decoding.py | 2 ++ vllm/v1/worker/gpu_model_runner.py | 28 +++++++++++++------ 11 files changed, 58 insertions(+), 21 deletions(-) diff --git 
a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 6ac68e055e57..07d3b8cded86 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -979,6 +979,7 @@ def create_deterministic_logits(token_ids): proposer.draft_attn_groups = [mock_attn_group] result = proposer.propose( + num_speculative_tokens=num_speculative_tokens, target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -1134,6 +1135,7 @@ def create_deterministic_logits(token_ids, k: int): # Propose draft tokens. result = proposer.propose( + num_speculative_tokens=num_speculative_tokens, target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, diff --git a/tests/v1/spec_decode/test_extract_hidden_states.py b/tests/v1/spec_decode/test_extract_hidden_states.py index af911e91d4b3..a788cd49ca82 100644 --- a/tests/v1/spec_decode/test_extract_hidden_states.py +++ b/tests/v1/spec_decode/test_extract_hidden_states.py @@ -264,6 +264,7 @@ def test_propose(): mock_has_kv.return_value = False draft_tokens, kv_connector_output = proposer.propose( + num_speculative_tokens=1, sampled_token_ids=sampled_token_ids, target_hidden_states=target_hidden_states, common_attn_metadata=common_attn_metadata, @@ -335,6 +336,7 @@ def test_propose_different_layer_counts(num_hidden_layers): mock_has_kv.return_value = False draft_tokens, _ = proposer.propose( + num_speculative_tokens=1, sampled_token_ids=sampled_token_ids, target_hidden_states=target_hidden_states, common_attn_metadata=common_attn_metadata, diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index 0a48b0e7b98c..9fa228fb05d5 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -204,6 +204,7 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): # Run propose result = proposer.propose( + num_speculative_tokens=num_speculative_tokens, target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 7d2a07ddcec7..459edddd1c2e 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -81,6 +81,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match. token_ids_cpu = np.array([[1, 2, 3, 4, 5]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -90,6 +91,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match for 4-gram. token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -99,6 +101,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match for 4-gram but match for 3-gram. 
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -109,6 +112,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # In this case, the proposer should return the 4-gram match. token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -118,6 +122,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Match for 2-gram and 3-gram, but not 4-gram. token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -127,6 +132,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Multiple 3-gram matched, but always pick the first one. token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -136,6 +142,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # check empty input token_ids_cpu = np.array([[]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0]], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -147,6 +154,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # second request has 3 tokens and no match. 
Padded with -1 for max len 5 token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( + num_speculative_tokens=2, sampled_token_ids=[[0], [1]], num_tokens_no_spec=np.array([5, 3]), token_ids_cpu=token_ids_cpu, @@ -166,6 +174,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: num_tokens_no_spec = np.array([5, 3, 5], dtype=np.int32) sampled_token_ids = [[2], [], [8]] # Empty list for request 1 simulates prefill result = proposer.propose( + num_speculative_tokens=2, sampled_token_ids=sampled_token_ids, num_tokens_no_spec=num_tokens_no_spec, token_ids_cpu=token_ids_cpu, @@ -195,6 +204,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: input_2[:3] = [4, 5, 6] token_ids_cpu = np.array([input_1, input_2]) result = ngram_proposer.propose( + num_speculative_tokens=2, sampled_token_ids=[[0], [1]], num_tokens_no_spec=np.array([len(input_1), 3]), token_ids_cpu=token_ids_cpu, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 0f174c6ca737..099baa61e4e2 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -384,7 +384,7 @@ def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor: def propose( self, - optimal_num_speculative_tokens: int | None, + num_speculative_tokens: int, # [num_tokens] target_token_ids: torch.Tensor, # [num_tokens] or [3, num_tokens] when M-RoPE is enabled @@ -402,9 +402,7 @@ def propose( | list[dict[str, torch.Tensor]] | None = None, ) -> torch.Tensor: - # Use optimal num speculative tokens if provided - if optimal_num_speculative_tokens is not None: - self.num_speculative_tokens = optimal_num_speculative_tokens + self.num_speculative_tokens = num_speculative_tokens batch_size = common_attn_metadata.batch_size() diff --git a/vllm/v1/spec_decode/extract_hidden_states.py b/vllm/v1/spec_decode/extract_hidden_states.py index 38a54f01696c..b3f7ab37bc29 100644 --- a/vllm/v1/spec_decode/extract_hidden_states.py +++ b/vllm/v1/spec_decode/extract_hidden_states.py @@ -32,7 +32,10 @@ class ExtractHiddenStatesProposer: def __init__(self, vllm_config: VllmConfig, device): assert vllm_config.speculative_config is not None - assert vllm_config.speculative_config.num_speculative_tokens == 1 + self.num_speculative_tokens = ( + vllm_config.speculative_config.num_speculative_tokens + ) + assert self.num_speculative_tokens == 1 if vllm_config.speculative_config.disable_padded_drafter_batch: raise ValueError( "disable_padded_drafter_batch is not supported with " @@ -76,6 +79,7 @@ def __init__(self, vllm_config: VllmConfig, device): def propose( self, + num_speculative_tokens: int, sampled_token_ids: torch.Tensor, target_hidden_states: list[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, @@ -108,6 +112,7 @@ def propose( - Draft tokens matching sampled tokens, shape [batch_size, 1] - KV connector output (if KV transfer is active), else None """ + assert num_speculative_tokens == self.num_speculative_tokens assert self.model is not None and isinstance(target_hidden_states, list) # target_hidden_states is a list of tensors (one per layer) diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index b3d8745e19c5..7adf7cff5f77 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -35,16 +35,18 @@ def __init__( self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.hidden_size = self.spec_config.draft_model_config.get_hidden_size() self.dtype 
= vllm_config.model_config.dtype - self.num_speculative_tokens = None + self.num_speculative_tokens = self.spec_config.num_speculative_tokens def propose( self, + num_speculative_tokens: int, target_hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, # unused ) -> torch.Tensor: + assert num_speculative_tokens == self.num_speculative_tokens # Generate blocks and compute logits blocks = self.model(target_hidden_states) logits = self.model.compute_logits(blocks) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index d20a10ee4d3b..e0240d0e66b8 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -55,7 +55,7 @@ def __init__(self, vllm_config: VllmConfig): # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose( - None, + self.k, [[]] * 1024, np.zeros(1024, dtype=np.int32), np.zeros((1024, self.max_model_len), dtype=np.int32), @@ -67,6 +67,7 @@ def batch_propose( valid_ngram_requests: list, num_tokens_no_spec: np.ndarray, token_ids_cpu: np.ndarray, + k: int, ) -> list[list[int]]: """Batch version of ngram proposer using numba for acceleration. @@ -79,6 +80,8 @@ def batch_propose( token_ids_cpu: Numpy array of shape (batch_size, max_model_len) representing the token IDs for each request. + k: + Number of speculative tokens to propose. Returns: list[list[int]]: @@ -111,7 +114,7 @@ def batch_propose( self.min_n, self.max_n, self.max_model_len, - self.k, + k, self.valid_ngram_draft, self.valid_ngram_num_drafts, ) @@ -131,7 +134,7 @@ def batch_propose( def propose( self, - optimal_num_speculative_tokens: int | None, + num_speculative_tokens: int, sampled_token_ids: list[list[int]], num_tokens_no_spec: np.ndarray, token_ids_cpu: np.ndarray, @@ -139,9 +142,7 @@ def propose( | list[dict[str, torch.Tensor]] | None = None, # unused ) -> list[list[int]]: - # Use optimal num speculative tokens if provided - if optimal_num_speculative_tokens is not None: - self.k = optimal_num_speculative_tokens + assert num_speculative_tokens <= self.k # find which requests need ngram proposals valid_ngram_requests = [] @@ -163,6 +164,7 @@ def propose( valid_ngram_requests, num_tokens_no_spec, token_ids_cpu, + num_speculative_tokens, ) return draft_token_ids diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py index 3ff84180463d..3837618082cf 100644 --- a/vllm/v1/spec_decode/ngram_proposer_gpu.py +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -313,6 +313,7 @@ def _generate_dummy_data( def propose( self, + num_speculative_tokens: int, num_tokens_no_spec: torch.Tensor, # [batch_size] token_ids_gpu: torch.Tensor, # [batch_size, max_len] valid_sampled_token_ids_gpu: torch.Tensor, # [batch_size, num_spec_tokens + 1] @@ -325,6 +326,7 @@ def propose( updated lengths, then run the kernel. Args: + num_speculative_tokens: Number of speculative tokens to propose. 
num_tokens_no_spec: Number of tokens per sequence (read-only) token_ids_gpu: Token IDs tensor (modified in-place with new tokens) valid_sampled_token_ids_gpu: Newly sampled tokens to scatter @@ -335,6 +337,7 @@ def propose( num_valid_draft_tokens: Count of leading valid draft tokens per request [batch_size] """ + assert num_speculative_tokens == self.k assert token_ids_gpu.device == self.device assert num_tokens_no_spec.device == self.device diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index fee5d97468f3..66137a006316 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -34,12 +34,14 @@ def __init__(self, vllm_config: VllmConfig): def propose( self, + num_speculative_tokens: int, input_batch: InputBatch, sampled_token_ids: list[list[int]], slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, # unused ) -> list[list[int]]: + assert num_speculative_tokens == self.num_speculative_tokens """ Propose speculative tokens for each request in the input batch. Suffix Decoding will speculate a dynamic number of tokens for each request every decoding step, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ecfbfe2b6319..0ebc872f0ec6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4288,24 +4288,28 @@ def propose_draft_token_ids( common_attn_metadata: CommonAttentionMetadata, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, ) -> list[list[int]] | torch.Tensor: - optimal_num_speculative_tokens = None + spec_config = self.speculative_config + assert spec_config is not None + if self.dynamic_sd_manager: - optimal_num_speculative_tokens = self.dynamic_sd_manager.step( + num_speculative_tokens = self.dynamic_sd_manager.step( scheduler_output.spec_decoding_stats_all, self.input_batch.num_reqs, ) - self._optimal_num_speculative_tokens = optimal_num_speculative_tokens + else: + num_speculative_tokens = spec_config.num_speculative_tokens + self._optimal_num_speculative_tokens = ( + num_speculative_tokens if self.dynamic_sd_manager else None + ) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - spec_config = self.speculative_config - assert spec_config is not None if spec_config.method == "ngram": from vllm.v1.spec_decode.ngram_proposer import NgramProposer assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( - optimal_num_speculative_tokens, + num_speculative_tokens, sampled_token_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, @@ -4331,6 +4335,7 @@ def propose_draft_token_ids( batch_size = next_token_ids.shape[0] draft_token_ids, num_valid_draft_tokens = self.drafter.propose( + num_speculative_tokens, self.num_tokens_no_spec_gpu[:batch_size], self.token_ids_gpu_tensor[:batch_size], valid_sampled_token_ids_gpu, @@ -4352,7 +4357,10 @@ def propose_draft_token_ids( assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer) draft_token_ids = self.drafter.propose( - self.input_batch, sampled_token_ids, slot_mappings=slot_mappings + num_speculative_tokens, + self.input_batch, + sampled_token_ids, + slot_mappings=slot_mappings, ) elif spec_config.method == "medusa": assert isinstance(sampled_token_ids, list) @@ -4376,6 +4384,7 @@ def propose_draft_token_ids( hidden_states = sample_hidden_states[indices] draft_token_ids = 
self.drafter.propose( + num_speculative_tokens=num_speculative_tokens, target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, slot_mappings=slot_mappings, @@ -4393,6 +4402,7 @@ def propose_draft_token_ids( target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states] draft_token_ids, drafter_kv_connector_output = self.drafter.propose( + num_speculative_tokens=num_speculative_tokens, sampled_token_ids=sampled_token_ids, target_hidden_states=target_hidden_states, common_attn_metadata=common_attn_metadata, @@ -4522,7 +4532,7 @@ def propose_draft_token_ids( mm_embed_inputs = None draft_token_ids = self.drafter.propose( - optimal_num_speculative_tokens=optimal_num_speculative_tokens, + num_speculative_tokens=num_speculative_tokens, target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -4536,7 +4546,7 @@ def propose_draft_token_ids( ) if ( - optimal_num_speculative_tokens is not None + self.dynamic_sd_manager and isinstance(draft_token_ids, torch.Tensor) and draft_token_ids.dim() == 2 and draft_token_ids.shape[1] < self.num_spec_tokens From 046c39aa5606f08fea55136661b4fb68df8034c9 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 17 Mar 2026 05:18:31 +0000 Subject: [PATCH 34/39] fix comma Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d5db15b7a810..43ca224e11ed 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4402,7 +4402,7 @@ def propose_draft_token_ids( target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states] draft_token_ids = self.drafter.propose( - num_speculative_tokens=num_speculative_tokens + num_speculative_tokens=num_speculative_tokens, sampled_token_ids=sampled_token_ids, target_hidden_states=target_hidden_states, common_attn_metadata=common_attn_metadata, From 917e3debbc9f75837eac895d092d7961b59f21d1 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 31 Mar 2026 03:47:32 +0000 Subject: [PATCH 35/39] move towards DSD scheduler Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/core/sched/async_scheduler.py | 88 ------------------------- vllm/v1/core/sched/output.py | 4 ++ vllm/v1/core/sched/scheduler.py | 27 ++++++-- vllm/v1/spec_decode/dynamic/__init__.py | 2 + vllm/v1/spec_decode/dynamic/manager.py | 8 +++ vllm/v1/worker/gpu_model_runner.py | 81 +++-------------------- 6 files changed, 45 insertions(+), 165 deletions(-) create mode 100644 vllm/v1/spec_decode/dynamic/__init__.py diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index bb4f3f9a11ff..32992aaf9e0d 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -1,19 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - -from typing import TYPE_CHECKING - from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.request import Request, RequestStatus -if TYPE_CHECKING: - from vllm.v1.engine import EngineCoreOutputs - from vllm.v1.outputs import 
ModelRunnerOutput - logger = init_logger(__name__) @@ -22,90 +14,10 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # reusable read-only placeholder list for speculative decoding. self._spec_token_placeholders: list[int] = [-1] * self.num_spec_tokens - # Dynamic SD: effective K for the current step, updated from - # ModelRunnerOutput after each model execution. - self._dynamic_num_spec_tokens: int | None = None - - # Dynamic SD async scheduling state: - # optimal K learned from the latest model output, to be applied - # at the beginning of the next schedule() call. - self._pending_optimal_k: int | None = None - # req_id -> committed spec token count from the most recent - # _update_after_schedule. Used to correct accounting when - # the optimal K changes between scheduling and output processing. - self._in_flight_decode_req_k: dict[str, int] = {} - - def _update_placeholders_from_dynamic_sd(self, optimal_k: int | None) -> None: - """Update placeholder count based on Dynamic SD decision.""" - if optimal_k is None: - self._dynamic_num_spec_tokens = None - self._spec_token_placeholders = [-1] * self.num_spec_tokens - elif optimal_k != self._dynamic_num_spec_tokens: - self._dynamic_num_spec_tokens = optimal_k - self._spec_token_placeholders = [-1] * optimal_k - - def _apply_pending_dynamic_sd_update(self) -> None: - """Apply a deferred dynamic SD K change at the start of schedule(). - - When update_from_output() learns a new optimal K, it cannot - immediately correct the in-flight step's accounting because we - need to target the right set of requests (those committed in the - most recent _update_after_schedule, tracked in - _in_flight_decode_req_k). This method is called at the beginning - of schedule() so that: - 1. The in-flight step's over/under-committed accounting is fixed. - 2. request.spec_token_ids is updated before the scheduling loop - reads it. - 3. _spec_token_placeholders reflects the new K for the current - scheduling step. - """ - optimal_k = self._pending_optimal_k - if optimal_k is None: - return - - self._pending_optimal_k = None - self._update_placeholders_from_dynamic_sd(optimal_k) - - for req_id, committed_k in self._in_flight_decode_req_k.items(): - diff = committed_k - optimal_k - if diff <= 0: - # K stayed the same or increased; the in-flight step - # under-allocated (if anything) but we cannot retroactively - # add more spec tokens to a step already on the GPU. - # Just update spec_token_ids for the next scheduling step. 
- request = self.requests.get(req_id) - if request is not None and not request.is_finished(): - request.spec_token_ids = self._spec_token_placeholders - continue - - request = self.requests.get(req_id) - if request is None or request.is_finished(): - continue - - request.num_output_placeholders -= diff - request.num_computed_tokens -= diff - request.spec_token_ids = self._spec_token_placeholders - - self._in_flight_decode_req_k = {} - - def schedule(self) -> SchedulerOutput: - self._apply_pending_dynamic_sd_update() - return super().schedule() - - def update_from_output( - self, - scheduler_output: SchedulerOutput, - model_runner_output: ModelRunnerOutput, - ) -> dict[int, EngineCoreOutputs]: - optimal_k = model_runner_output.optimal_num_speculative_tokens - if optimal_k is not None: - self._pending_optimal_k = optimal_k - return super().update_from_output(scheduler_output, model_runner_output) def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: super()._update_after_schedule(scheduler_output) spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens - self._in_flight_decode_req_k = {} for req_id in scheduler_output.num_scheduled_tokens: request = self.requests[req_id] if request.is_prefill_chunk: diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 0564156c24aa..b6009cef55d1 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -243,6 +243,10 @@ class SchedulerOutput: # preventing stale NaN/data from corrupting attention or SSM computation. new_block_ids_to_zero: list[int] | None = None + # Dynamic speculative decoding: optimal K chosen by scheduler. + # Number of spec tokens to schedule for the next step. + num_spec_tokens_to_schedule: int | None = None + @classmethod def make_empty(cls) -> "SchedulerOutput": return cls( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 9b069cbc0a68..5e82ae304011 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -58,6 +58,7 @@ from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus, StreamingUpdate from vllm.v1.spec_decode.metrics import SpecDecodingStats +from vllm.v1.spec_decode.dynamic.manager import DynamicSpeculativeDecodingManager from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.utils import record_function_or_nullcontext @@ -213,10 +214,17 @@ def __init__( speculative_config = vllm_config.speculative_config self.use_eagle = False self.num_spec_tokens = self.num_lookahead_tokens = 0 - self.spec_decoding_stats_all = None + self.dynamic_sd_manager = None if speculative_config: self.num_spec_tokens = speculative_config.num_speculative_tokens - self.spec_decoding_stats_all = SpecDecodingStats.new(self.num_spec_tokens) + # setup Dynamic Speculative Decoding + self.dynamic_sd_manager: DynamicSpeculativeDecodingManager | None = None + if speculative_config.dynamic_config: + self.dynamic_sd_manager = DynamicSpeculativeDecodingManager( + speculative_config.dynamic_config, + self.scheduler_config.max_num_seqs, + self.num_spec_tokens, + ) if speculative_config.use_eagle(): self.use_eagle = True self.num_lookahead_tokens = self.num_spec_tokens @@ -906,6 +914,13 @@ def schedule(self) -> SchedulerOutput: else None ) + # Dynamic speculative decoding: compute optimal K + num_spec_tokens_to_schedule = None + if self.dynamic_sd_manager is not None and len(num_scheduled_tokens) > 0: + num_spec_tokens_to_schedule = 
self.dynamic_sd_manager.step( + len(num_scheduled_tokens) + ) + scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -923,6 +938,7 @@ def schedule(self) -> SchedulerOutput: free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), spec_decoding_stats_all=self.spec_decoding_stats_all, new_block_ids_to_zero=new_block_ids_to_zero, + num_spec_tokens_to_schedule=num_spec_tokens_to_schedule, ) # NOTE(Kuntai): this function is designed for multiple purposes: @@ -1978,6 +1994,11 @@ def make_spec_decoding_stats( num_invalid_spec_tokens: dict[str, int] | None, request_id: str, ) -> SpecDecodingStats | None: + if num_invalid_spec_tokens: + num_draft_tokens -= num_invalid_spec_tokens.get(request_id, 0) + if self.dynamic_sd_manager is not None and num_draft_tokens: + self.dynamic_sd_manager.observe_draft(num_draft_tokens, num_accepted_tokens) + # Save this so its accessible by scheduler and can # be sent to engine for Dynamic SD. if self.spec_decoding_stats_all is not None: @@ -1990,8 +2011,6 @@ def make_spec_decoding_stats( return None if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) - if num_invalid_spec_tokens: - num_draft_tokens -= num_invalid_spec_tokens.get(request_id, 0) spec_decoding_stats.observe_draft( num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens ) diff --git a/vllm/v1/spec_decode/dynamic/__init__.py b/vllm/v1/spec_decode/dynamic/__init__.py new file mode 100644 index 000000000000..0fec1fe5bcdf --- /dev/null +++ b/vllm/v1/spec_decode/dynamic/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project \ No newline at end of file diff --git a/vllm/v1/spec_decode/dynamic/manager.py b/vllm/v1/spec_decode/dynamic/manager.py index 1a983c57a631..1637966c756b 100644 --- a/vllm/v1/spec_decode/dynamic/manager.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.config.speculative import DynamicSpeculativeConfig +from vllm.v1.spec_decode.metrics import SpecDecodingStats class DynamicSpeculativeDecodingManager: @@ -26,6 +27,9 @@ def __init__( self.steps = 0 self.warmup_steps = warmup_steps + # Cumulative stats for online acceptance rate updates + self.stats = SpecDecodingStats.new(vllm_num_speculative_tokens) + # Sanity check assert ( vllm_num_speculative_tokens @@ -177,3 +181,7 @@ def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: chosen_num_drafts = num_drafts return chosen_num_drafts + + def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int) -> None: + """Record draft/accept counts for online acceptance rate updates.""" + self.stats.observe_draft(num_draft_tokens, num_accepted_tokens) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4fcf4e5a429a..c4a7d1168e6f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -590,15 +590,6 @@ def __init__( self.use_async_scheduling and self.num_spec_tokens > 0 ) - # setup Dynamic Speculative Decoding - self.dynamic_sd_manager: DynamicSpeculativeDecodingManager | None = None - if self.speculative_config.dynamic_config: - self.dynamic_sd_manager = DynamicSpeculativeDecodingManager( - self.speculative_config.dynamic_config, - self.vllm_config.scheduler_config.max_num_seqs, - self.speculative_config.num_speculative_tokens, - ) - # 
Request states. self.requests: dict[str, CachedRequestState] = {} # NOTE(rob): num_prompt_logprobs only includes reqs @@ -807,8 +798,7 @@ def __init__( self.runner_only_attn_layers: set[str] = set() # Cached outputs. - self._draft_token_ids: list[list[int]] | torch.Tensor | None = None - self._optimal_num_speculative_tokens: int | None = None + self._draft_token_ids: list[list[int]] | torch.Tensor | None = None # N-gram GPU path: async D2H buffer/event for per-request valid draft counts. self._num_valid_draft_tokens: torch.Tensor | None = None self._num_valid_draft_tokens_cpu: torch.Tensor | None = None @@ -1206,16 +1196,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None req_data = scheduler_output.scheduled_cached_reqs scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens - # When dynamic SD chose a smaller K in the previous step, trim the - # scheduler output in-place so that input preparation, rejection - # accounting, and prev_num_draft_len all reflect the actual K. - if ( - self._optimal_num_speculative_tokens is not None - and self.use_async_scheduling - and scheduled_spec_tokens - ): - self._trim_spec_tokens_for_dynamic_sd(scheduler_output) - # Save scheduler-allocated spec lengths before trimming so # prev_num_draft_len keeps the optimistic count for rejection correction. original_num_spec_per_req: dict[str, int] = {} @@ -4342,7 +4322,6 @@ def propose_draft_token_ids(sampled_token_ids): else None, num_nans_in_logits=num_nans_in_logits, cudagraph_stats=cudagraph_stats, - optimal_num_speculative_tokens=(self._optimal_num_speculative_tokens), ) if not self.use_async_scheduling: @@ -4410,43 +4389,10 @@ def _pp_receive_prev_sampled_token_ids_to_input_batch(self) -> None: req_state.output_token_ids.append(-1) self.input_batch.prev_req_id_to_index = prev_req_id_to_index - def _trim_spec_tokens_for_dynamic_sd( - self, scheduler_output: "SchedulerOutput" - ) -> None: - """Trim scheduled spec tokens to match dynamic SD's optimal K. - - Called in _update_states when async scheduling is active and the - previous step determined a lower optimal K than what the scheduler - allocated. Modifies scheduler_output in-place so that the engine - core's update_from_output sees the trimmed counts in its rejection - logic, avoiding double-correction with the scheduler-side fix. - """ - k = self._optimal_num_speculative_tokens - assert k is not None - spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens - for req_id in list(spec_decode_tokens): - tokens = spec_decode_tokens[req_id] - scheduled_k = len(tokens) - if scheduled_k <= k: - continue - tokens_to_trim = scheduled_k - k - scheduler_output.total_num_scheduled_tokens -= tokens_to_trim - scheduler_output.num_scheduled_tokens[req_id] -= tokens_to_trim - if k == 0: - spec_decode_tokens.pop(req_id) - else: - spec_decode_tokens[req_id] = tokens[:k] - def take_draft_token_ids(self) -> DraftTokenIds | None: if not self.num_spec_tokens or not self._draft_token_req_ids: return None draft_token_ids, req_ids = self._get_draft_token_ids_cpu() - # When Dynamic SD reduced the number of speculative tokens, - # the GPU tensor was zero-padded to num_spec_tokens for scatter - # indexing, but the scheduler should only see the real draft tokens. 
- k = self._optimal_num_speculative_tokens - if k is not None and k < self.num_spec_tokens: - draft_token_ids = [ids[:k] for ids in draft_token_ids] return DraftTokenIds(req_ids, draft_token_ids) def _copy_draft_token_ids_to_cpu( @@ -4539,28 +4485,17 @@ def propose_draft_token_ids( common_attn_metadata: CommonAttentionMetadata, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, ) -> list[list[int]] | torch.Tensor: + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None - if self.dynamic_sd_manager: - num_speculative_tokens = self.dynamic_sd_manager.step( - scheduler_output.spec_decoding_stats_all, - self.input_batch.num_reqs, - ) - else: - num_speculative_tokens = spec_config.num_speculative_tokens - self._optimal_num_speculative_tokens = ( - num_speculative_tokens if self.dynamic_sd_manager else None - ) - - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if spec_config.method == "ngram": from vllm.v1.spec_decode.ngram_proposer import NgramProposer assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( - num_speculative_tokens, + scheduler_output.num_spec_tokens_to_schedule, sampled_token_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, @@ -4586,7 +4521,7 @@ def propose_draft_token_ids( batch_size = next_token_ids.shape[0] draft_token_ids, num_valid_draft_tokens = self.drafter.propose( - num_speculative_tokens, + scheduler_output.num_spec_tokens_to_schedule, self.num_tokens_no_spec_gpu[:batch_size], self.token_ids_gpu_tensor[:batch_size], valid_sampled_token_ids_gpu, @@ -4608,7 +4543,7 @@ def propose_draft_token_ids( assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer) draft_token_ids = self.drafter.propose( - num_speculative_tokens, + scheduler_output.num_spec_tokens_to_schedule, self.input_batch, sampled_token_ids, slot_mappings=slot_mappings, @@ -4635,7 +4570,7 @@ def propose_draft_token_ids( hidden_states = sample_hidden_states[indices] draft_token_ids = self.drafter.propose( - num_speculative_tokens=num_speculative_tokens, + num_speculative_tokens=scheduler_output.num_spec_tokens_to_schedule, target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, slot_mappings=slot_mappings, @@ -4653,7 +4588,7 @@ def propose_draft_token_ids( target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states] draft_token_ids = self.drafter.propose( - num_speculative_tokens=num_speculative_tokens, + num_speculative_tokens=scheduler_output.num_spec_tokens_to_schedule, sampled_token_ids=sampled_token_ids, target_hidden_states=target_hidden_states, common_attn_metadata=common_attn_metadata, @@ -4778,7 +4713,7 @@ def propose_draft_token_ids( mm_embed_inputs = None draft_token_ids = self.drafter.propose( - num_speculative_tokens=num_speculative_tokens, + num_speculative_tokens=scheduler_output.num_spec_tokens_to_schedule, target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, From 3e94ad03cac7f06207bc14678692d5b866f0cee8 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 31 Mar 2026 03:50:18 +0000 Subject: [PATCH 36/39] move towards DSD scheduler Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/core/sched/async_scheduler.py | 3 --- vllm/v1/core/sched/output.py 
| 5 ----- vllm/v1/core/sched/scheduler.py | 9 --------- 3 files changed, 17 deletions(-) diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index 32992aaf9e0d..0b3958dbcf5a 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -34,9 +34,6 @@ def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: # We will update the actual spec token ids in the worker process. request.spec_token_ids = self._spec_token_placeholders - if cur_num_spec_tokens > 0: - self._in_flight_decode_req_k[req_id] = cur_num_spec_tokens - def _update_request_with_output( self, request: Request, new_token_ids: list[int] ) -> tuple[list[int], bool]: diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index b6009cef55d1..32f63d32e8dd 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -5,8 +5,6 @@ from functools import cached_property from typing import TYPE_CHECKING -from vllm.v1.spec_decode.metrics import SpecDecodingStats - if TYPE_CHECKING: import numpy as np import numpy.typing as npt @@ -235,9 +233,6 @@ class SchedulerOutput: # EC Cache Connector metadata ec_connector_metadata: ECConnectorMetadata | None = None - # Spec Decoding stats for all requests. - spec_decoding_stats_all: SpecDecodingStats | None = None - # Block IDs freshly allocated from the pool during this scheduling step. # The worker zeros the corresponding GPU memory before the blocks are used, # preventing stale NaN/data from corrupting attention or SSM computation. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5e82ae304011..24ac1ec5a471 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -936,7 +936,6 @@ def schedule(self) -> SchedulerOutput: # the previous and the current steps. finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), - spec_decoding_stats_all=self.spec_decoding_stats_all, new_block_ids_to_zero=new_block_ids_to_zero, num_spec_tokens_to_schedule=num_spec_tokens_to_schedule, ) @@ -1999,14 +1998,6 @@ def make_spec_decoding_stats( if self.dynamic_sd_manager is not None and num_draft_tokens: self.dynamic_sd_manager.observe_draft(num_draft_tokens, num_accepted_tokens) - # Save this so its accessible by scheduler and can - # be sent to engine for Dynamic SD. 
- if self.spec_decoding_stats_all is not None: - self.spec_decoding_stats_all.observe_draft( - num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens, - ) - if not self.log_stats or not num_draft_tokens: return None if spec_decoding_stats is None: From 8aa39fbb3ed5ef1fa471762d23852d40b814696e Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Wed, 1 Apr 2026 00:02:33 +0000 Subject: [PATCH 37/39] fix padded drafter Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/core/sched/async_scheduler.py | 4 +++ vllm/v1/core/sched/scheduler.py | 6 ++--- vllm/v1/outputs.py | 4 --- vllm/v1/spec_decode/dynamic/__init__.py | 2 +- vllm/v1/spec_decode/dynamic/manager.py | 14 +++++----- vllm/v1/worker/gpu_model_runner.py | 36 +++++++++---------------- 6 files changed, 27 insertions(+), 39 deletions(-) diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index 0b3958dbcf5a..dcb1e34f1c36 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -18,6 +18,10 @@ def __init__(self, *args, **kwargs) -> None: def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: super()._update_after_schedule(scheduler_output) spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens + # Use the latest num of draft tokens to schedule in the next step as placeholder. + self._spec_token_placeholders = [ + -1 + ] * scheduler_output.num_spec_tokens_to_schedule for req_id in scheduler_output.num_scheduled_tokens: request = self.requests[req_id] if request.is_prefill_chunk: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 24ac1ec5a471..d1e98b4280ae 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -57,8 +57,8 @@ from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus, StreamingUpdate -from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.dynamic.manager import DynamicSpeculativeDecodingManager +from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.utils import record_function_or_nullcontext @@ -915,7 +915,7 @@ def schedule(self) -> SchedulerOutput: ) # Dynamic speculative decoding: compute optimal K - num_spec_tokens_to_schedule = None + num_spec_tokens_to_schedule = self.num_spec_tokens if self.dynamic_sd_manager is not None and len(num_scheduled_tokens) > 0: num_spec_tokens_to_schedule = self.dynamic_sd_manager.step( len(num_scheduled_tokens) @@ -1997,7 +1997,7 @@ def make_spec_decoding_stats( num_draft_tokens -= num_invalid_spec_tokens.get(request_id, 0) if self.dynamic_sd_manager is not None and num_draft_tokens: self.dynamic_sd_manager.observe_draft(num_draft_tokens, num_accepted_tokens) - + if not self.log_stats or not num_draft_tokens: return None if spec_decoding_stats is None: diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 221c705c2142..8eb58de4f3fd 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -253,10 +253,6 @@ class ModelRunnerOutput: # information related to cudagraph execution cudagraph_stats: CUDAGraphStat | None = None - # Dynamic Speculative Decoding: optimal K chosen for this step. - # None means DSD is not active (use the static num_spec_tokens). 
- optimal_num_speculative_tokens: int | None = None - # ModelRunnerOutput wrapper for async scheduling. class AsyncModelRunnerOutput(ABC): diff --git a/vllm/v1/spec_decode/dynamic/__init__.py b/vllm/v1/spec_decode/dynamic/__init__.py index 0fec1fe5bcdf..208f01a7cb5e 100644 --- a/vllm/v1/spec_decode/dynamic/__init__.py +++ b/vllm/v1/spec_decode/dynamic/__init__.py @@ -1,2 +1,2 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project \ No newline at end of file +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm/v1/spec_decode/dynamic/manager.py b/vllm/v1/spec_decode/dynamic/manager.py index 1637966c756b..a32ba7282cfa 100644 --- a/vllm/v1/spec_decode/dynamic/manager.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -66,12 +66,10 @@ def __init__( self.update_optimal_num_speculative_tokens() - def step(self, spec_decoding_stats_all, batch_size: int) -> int: + def step(self, batch_size: int) -> int: self.steps += 1 if self.should_update(): - acceptance_rate_per_pos = self.compute_acceptance_rate_per_pos( - spec_decoding_stats_all - ) + acceptance_rate_per_pos = self.compute_acceptance_rate_per_pos() self.update_acceptance_rate_per_pos(acceptance_rate_per_pos) optimal_num_speculative_tokens = self.get_optimal_num_speculative_tokens( @@ -80,15 +78,15 @@ def step(self, spec_decoding_stats_all, batch_size: int) -> int: return optimal_num_speculative_tokens - def compute_acceptance_rate_per_pos(self, spec_decoding_stats_all) -> list[float]: + def compute_acceptance_rate_per_pos(self) -> list[float]: acceptance_rate_per_pos = [] for i in range(self.vllm_num_speculative_tokens): - if spec_decoding_stats_all.num_draft_tokens_per_pos[i] == 0: + if self.stats.num_draft_tokens_per_pos[i] == 0: acceptance_rate = 0.0 else: acceptance_rate = ( - spec_decoding_stats_all.num_accepted_tokens_per_pos[i] - / spec_decoding_stats_all.num_draft_tokens_per_pos[i] + self.stats.num_accepted_tokens_per_pos[i] + / self.stats.num_draft_tokens_per_pos[i] ) acceptance_rate_per_pos.append(acceptance_rate) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c4a7d1168e6f..a5899f19d856 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -163,7 +163,6 @@ from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.dflash import DFlashProposer from vllm.v1.spec_decode.draft_model import DraftModelProposer -from vllm.v1.spec_decode.dynamic.manager import DynamicSpeculativeDecodingManager from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer from vllm.v1.spec_decode.medusa import MedusaProposer @@ -578,9 +577,11 @@ def __init__( self.rejection_sampler = RejectionSampler(self.sampler) self.num_spec_tokens = 0 + self.prev_num_spec_tokens = 0 self.valid_sampled_token_count_gpu: torch.Tensor | None = None if self.speculative_config: self.num_spec_tokens = self.speculative_config.num_speculative_tokens + self.prev_num_spec_tokens = self.num_spec_tokens draft_config = self.speculative_config.draft_model_config if draft_config is not None and draft_config.max_model_len is not None: self.effective_drafter_max_model_len = draft_config.max_model_len @@ -798,7 +799,7 @@ def __init__( self.runner_only_attn_layers: set[str] = set() # Cached outputs. 
- self._draft_token_ids: list[list[int]] | torch.Tensor | None = None + self._draft_token_ids: list[list[int]] | torch.Tensor | None = None # N-gram GPU path: async D2H buffer/event for per-request valid draft counts. self._num_valid_draft_tokens: torch.Tensor | None = None self._num_valid_draft_tokens_cpu: torch.Tensor | None = None @@ -1670,7 +1671,7 @@ def _prepare_input_ids( spec_flattened_indices.extend( range(flattened_index - draft_len + 1, flattened_index + 1) ) - start = prev_index * self.num_spec_tokens + start = prev_index * self.prev_num_spec_tokens # prev_draft_token_indices is used to find which draft_tokens_id # should be copied to input_ids # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] @@ -4257,8 +4258,6 @@ def propose_draft_token_ids(sampled_token_ids): self._copy_valid_sampled_token_count( next_token_ids, valid_sampled_tokens_count ) - # Since we couldn't run the drafter, - # just use zeros for the draft tokens. self._draft_token_ids = torch.zeros( 1, device=self.device, dtype=torch.int32 ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) @@ -4398,6 +4397,8 @@ def take_draft_token_ids(self) -> DraftTokenIds | None: def _copy_draft_token_ids_to_cpu( self, scheduler_output: "SchedulerOutput", zeros_only: bool = False ) -> None: + if torch.is_tensor(self._draft_token_ids): + self.prev_num_spec_tokens = self._draft_token_ids.shape[1] # Check if we need to copy draft tokens to CPU. In async scheduling, # we only copy when needed for structured output, penalties or bad_words. if self.use_async_scheduling and not ( @@ -4416,16 +4417,17 @@ def _copy_draft_token_ids_to_cpu( assert self.draft_token_ids_cpu is not None default_stream = torch.cuda.current_stream() num_reqs = draft_token_ids.shape[0] + num_spec_tokens = draft_token_ids.shape[1] with torch.cuda.stream(self.draft_token_ids_copy_stream): if not zeros_only: # Trigger async copy of draft token ids to cpu. self.draft_token_ids_copy_stream.wait_stream(default_stream) - self.draft_token_ids_cpu[:num_reqs].copy_( + self.draft_token_ids_cpu[:num_reqs, :num_spec_tokens].copy_( draft_token_ids, non_blocking=True ) else: # No copy needed, just zero-out cpu tensor. 
- self.draft_token_ids_cpu[:num_reqs] = 0 + self.draft_token_ids_cpu[:num_reqs, :num_spec_tokens] = 0 self.draft_token_ids_event.record() def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]: @@ -4437,7 +4439,10 @@ def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]: assert self.draft_token_ids_event is not None assert self.draft_token_ids_cpu is not None self.draft_token_ids_event.synchronize() - return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids + num_spec_tokens = self._draft_token_ids.shape[1] + return self.draft_token_ids_cpu[ + : len(req_ids), :num_spec_tokens + ].tolist(), req_ids def _copy_valid_sampled_token_count( self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor @@ -4488,7 +4493,6 @@ def propose_draft_token_ids( num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None - if spec_config.method == "ngram": from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -4726,20 +4730,6 @@ def propose_draft_token_ids( slot_mappings=slot_mappings, ) - if ( - self.dynamic_sd_manager - and isinstance(draft_token_ids, torch.Tensor) - and draft_token_ids.dim() == 2 - and draft_token_ids.shape[1] < self.num_spec_tokens - ): - padding = torch.zeros( - draft_token_ids.shape[0], - self.num_spec_tokens - draft_token_ids.shape[1], - device=draft_token_ids.device, - dtype=draft_token_ids.dtype, - ) - draft_token_ids = torch.cat([draft_token_ids, padding], dim=1) - return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: From 272409768725300b5fb9de9a6ea202dc7c9f9f69 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Wed, 1 Apr 2026 00:16:16 +0000 Subject: [PATCH 38/39] lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/config/speculative.py | 6 +++--- vllm/v1/core/sched/async_scheduler.py | 2 +- vllm/v1/core/sched/output.py | 2 +- vllm/v1/core/sched/scheduler.py | 4 +--- vllm/v1/worker/gpu_model_runner.py | 16 ++++++++++------ 5 files changed, 16 insertions(+), 14 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 2beeacbf7d32..114c61dc6f03 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -73,7 +73,7 @@ class DynamicSpeculativeConfig: is_online: bool = False """Whether the statistics are updated online or not during inference.""" - batch_stats: dict[int, dict[int, float]] = None + batch_stats: dict[int, dict[int, float]] | None = None """ Batch statistics for different batch sizes and number of drafts. The structure is as follows: @@ -92,10 +92,10 @@ class DynamicSpeculativeConfig: where bs 1 at K=3 has itl 9.41ms. K=0 means no speculative decoding. 
""" - max_num_speculative_tokens: int = None + max_num_speculative_tokens: int | None = None """Maximum number of speculative tokens supported in the statistics.""" - acceptance_rate_per_pos: list[float] = None + acceptance_rate_per_pos: list[float] | None = None """Acceptance rate per position on an offline dataset.""" diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index dcb1e34f1c36..381b758870ff 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -18,7 +18,7 @@ def __init__(self, *args, **kwargs) -> None: def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: super()._update_after_schedule(scheduler_output) spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens - # Use the latest num of draft tokens to schedule in the next step as placeholder. + # Use the latest num of scheduled draft tokens in next step as placeholder. self._spec_token_placeholders = [ -1 ] * scheduler_output.num_spec_tokens_to_schedule diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 32f63d32e8dd..08e26476ac34 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -240,7 +240,7 @@ class SchedulerOutput: # Dynamic speculative decoding: optimal K chosen by scheduler. # Number of spec tokens to schedule for the next step. - num_spec_tokens_to_schedule: int | None = None + num_spec_tokens_to_schedule: int = 0 @classmethod def make_empty(cls) -> "SchedulerOutput": diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d1e98b4280ae..18f51691a160 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -214,11 +214,9 @@ def __init__( speculative_config = vllm_config.speculative_config self.use_eagle = False self.num_spec_tokens = self.num_lookahead_tokens = 0 - self.dynamic_sd_manager = None + self.dynamic_sd_manager: DynamicSpeculativeDecodingManager | None = None if speculative_config: self.num_spec_tokens = speculative_config.num_speculative_tokens - # setup Dynamic Speculative Decoding - self.dynamic_sd_manager: DynamicSpeculativeDecodingManager | None = None if speculative_config.dynamic_config: self.dynamic_sd_manager = DynamicSpeculativeDecodingManager( speculative_config.dynamic_config, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a5899f19d856..32dd232fff65 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4398,6 +4398,7 @@ def _copy_draft_token_ids_to_cpu( self, scheduler_output: "SchedulerOutput", zeros_only: bool = False ) -> None: if torch.is_tensor(self._draft_token_ids): + assert isinstance(self._draft_token_ids, torch.Tensor) self.prev_num_spec_tokens = self._draft_token_ids.shape[1] # Check if we need to copy draft tokens to CPU. In async scheduling, # we only copy when needed for structured output, penalties or bad_words. 
@@ -4439,6 +4440,7 @@ def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]: assert self.draft_token_ids_event is not None assert self.draft_token_ids_cpu is not None self.draft_token_ids_event.synchronize() + assert isinstance(self._draft_token_ids, torch.Tensor) num_spec_tokens = self._draft_token_ids.shape[1] return self.draft_token_ids_cpu[ : len(req_ids), :num_spec_tokens @@ -4493,13 +4495,15 @@ def propose_draft_token_ids( num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None + num_spec_tokens_to_schedule = ( + scheduler_output.num_spec_tokens_to_schedule) if spec_config.method == "ngram": from vllm.v1.spec_decode.ngram_proposer import NgramProposer assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( - scheduler_output.num_spec_tokens_to_schedule, + num_spec_tokens_to_schedule, sampled_token_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, @@ -4525,7 +4529,7 @@ def propose_draft_token_ids( batch_size = next_token_ids.shape[0] draft_token_ids, num_valid_draft_tokens = self.drafter.propose( - scheduler_output.num_spec_tokens_to_schedule, + num_spec_tokens_to_schedule, self.num_tokens_no_spec_gpu[:batch_size], self.token_ids_gpu_tensor[:batch_size], valid_sampled_token_ids_gpu, @@ -4547,7 +4551,7 @@ def propose_draft_token_ids( assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer) draft_token_ids = self.drafter.propose( - scheduler_output.num_spec_tokens_to_schedule, + num_spec_tokens_to_schedule, self.input_batch, sampled_token_ids, slot_mappings=slot_mappings, @@ -4574,7 +4578,7 @@ def propose_draft_token_ids( hidden_states = sample_hidden_states[indices] draft_token_ids = self.drafter.propose( - num_speculative_tokens=scheduler_output.num_spec_tokens_to_schedule, + num_speculative_tokens=num_spec_tokens_to_schedule, target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, slot_mappings=slot_mappings, @@ -4592,7 +4596,7 @@ def propose_draft_token_ids( target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states] draft_token_ids = self.drafter.propose( - num_speculative_tokens=scheduler_output.num_spec_tokens_to_schedule, + num_speculative_tokens=num_spec_tokens_to_schedule, sampled_token_ids=sampled_token_ids, target_hidden_states=target_hidden_states, common_attn_metadata=common_attn_metadata, @@ -4717,7 +4721,7 @@ def propose_draft_token_ids( mm_embed_inputs = None draft_token_ids = self.drafter.propose( - num_speculative_tokens=scheduler_output.num_spec_tokens_to_schedule, + num_speculative_tokens=num_spec_tokens_to_schedule, target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, From e57e624dbe90b90e8a70bee9274238bdf82d6317 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Wed, 1 Apr 2026 00:21:32 +0000 Subject: [PATCH 39/39] lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/dynamic/manager.py | 55 ++++++++++++++++---------- vllm/v1/worker/gpu_model_runner.py | 3 +- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/vllm/v1/spec_decode/dynamic/manager.py b/vllm/v1/spec_decode/dynamic/manager.py index a32ba7282cfa..6c49fea29cad 100644 --- a/vllm/v1/spec_decode/dynamic/manager.py +++ b/vllm/v1/spec_decode/dynamic/manager.py @@ -22,8 
+22,23 @@ def __init__( self.dynamic_config = dynamic_config self.vllm_max_batch_size = vllm_max_batch_size self.vllm_num_speculative_tokens = vllm_num_speculative_tokens - self.batch_stats = self.dynamic_config.batch_stats - self.available_batch_sizes = sorted(self.dynamic_config.batch_stats.keys()) + + assert dynamic_config.batch_stats is not None, ( + "batch_stats is required for dynamic speculative decoding" + ) + assert dynamic_config.max_num_speculative_tokens is not None, ( + "max_num_speculative_tokens is required for dynamic speculative decoding" + ) + assert dynamic_config.acceptance_rate_per_pos is not None, ( + "acceptance_rate_per_pos is required for dynamic speculative decoding" + ) + + self.batch_stats: dict[int, dict[int, float]] = dynamic_config.batch_stats + self.max_num_speculative_tokens: int = dynamic_config.max_num_speculative_tokens + self.acceptance_rate_per_pos: list[float] = ( + dynamic_config.acceptance_rate_per_pos + ) + self.available_batch_sizes = sorted(dynamic_config.batch_stats.keys()) self.steps = 0 self.warmup_steps = warmup_steps @@ -32,36 +47,35 @@ def __init__( # Sanity check assert ( - vllm_num_speculative_tokens - <= self.dynamic_config.max_num_speculative_tokens + vllm_num_speculative_tokens <= dynamic_config.max_num_speculative_tokens ), "vllm_num_speculative_tokens must be <= max_num_speculative_tokens" - assert self.dynamic_config.max_num_speculative_tokens == len( - self.dynamic_config.acceptance_rate_per_pos + assert dynamic_config.max_num_speculative_tokens == len( + dynamic_config.acceptance_rate_per_pos ), "max_num_speculative_tokens != len(acceptance_rate_per_pos)" - assert self.dynamic_config.max_num_speculative_tokens > 0, ( + assert dynamic_config.max_num_speculative_tokens > 0, ( "max_num_speculative_tokens must be > 0" ) - assert all( - 0.0 <= a <= 1.0 for a in self.dynamic_config.acceptance_rate_per_pos - ), "all acceptance_rate_per_pos values must be in [0.0, 1.0]" - assert 1 in self.dynamic_config.batch_stats, ( - f"BS 1 not found in {self.dynamic_config.batch_stats.keys()}" + assert all(0.0 <= a <= 1.0 for a in dynamic_config.acceptance_rate_per_pos), ( + "all acceptance_rate_per_pos values must be in [0.0, 1.0]" + ) + assert 1 in dynamic_config.batch_stats, ( + f"BS 1 not found in {dynamic_config.batch_stats.keys()}" ) - assert vllm_max_batch_size in self.dynamic_config.batch_stats, ( - f"max BS not found in {self.dynamic_config.batch_stats.keys()}" + assert vllm_max_batch_size in dynamic_config.batch_stats, ( + f"max BS not found in {dynamic_config.batch_stats.keys()}" ) for bs in self.available_batch_sizes: assert bs > 0 - assert 0 in self.dynamic_config.batch_stats[bs], ( + assert 0 in dynamic_config.batch_stats[bs], ( f"batch size {bs} must have draft 0 stats" ) - assert 1 in self.dynamic_config.batch_stats[bs], ( + assert 1 in dynamic_config.batch_stats[bs], ( f"batch size {bs} must have draft 1 stats" ) - assert sorted(self.dynamic_config.batch_stats[bs].keys()) == list( - self.dynamic_config.batch_stats[bs].keys() + assert sorted(dynamic_config.batch_stats[bs].keys()) == list( + dynamic_config.batch_stats[bs].keys() ), f"batch size {bs} draft keys must be sorted" self.update_optimal_num_speculative_tokens() @@ -111,6 +125,7 @@ def update_optimal_num_speculative_tokens(self): } def update_acceptance_rate_per_pos(self, acceptance_rate_per_pos: list[float]): + self.acceptance_rate_per_pos = acceptance_rate_per_pos self.dynamic_config.acceptance_rate_per_pos = acceptance_rate_per_pos self.update_optimal_num_speculative_tokens() 
@@ -170,8 +185,8 @@ def _compute_optimal_num_speculative_tokens(self, batch_size: int) -> int: batch_stats = self._get_batch_stats(batch_size) max_goodput = -1.0 - for num_drafts in range(self.dynamic_config.max_num_speculative_tokens + 1): - curr_al = 1 + sum(self.dynamic_config.acceptance_rate_per_pos[:num_drafts]) + for num_drafts in range(self.max_num_speculative_tokens + 1): + curr_al = 1 + sum(self.acceptance_rate_per_pos[:num_drafts]) curr_itl = self._get_itl(batch_stats, num_drafts) curr_goodput = curr_al / curr_itl if curr_goodput > max_goodput: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 32dd232fff65..34c5fb8731fe 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4495,8 +4495,7 @@ def propose_draft_token_ids( num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None - num_spec_tokens_to_schedule = ( - scheduler_output.num_spec_tokens_to_schedule) + num_spec_tokens_to_schedule = scheduler_output.num_spec_tokens_to_schedule if spec_config.method == "ngram": from vllm.v1.spec_decode.ngram_proposer import NgramProposer