-
-
Notifications
You must be signed in to change notification settings - Fork 18.7k
[V1][Spec Decode] Add Dynamic SD #32374
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 9 commits
Commits
Show all changes
66 commits
Select commit
Hold shift + click to select a range
8e7bcc9
save work
ekagra-ranjan bee764b
move to dynamic folder and scripts are in working condition
ekagra-ranjan e61145f
sync main
ekagra-ranjan 06916e5
pre
ekagra-ranjan 5e63d51
start stiching
ekagra-ranjan b683f26
pipeline works and offline script moved
ekagra-ranjan 618b5fd
load dynamic sd config
ekagra-ranjan c540a75
remove offline bkp
ekagra-ranjan dfb2b31
remove
ekagra-ranjan bb65365
add runtime AL to goodput after warmup
ekagra-ranjan 5fcf59e
Update vllm/v1/spec_decode/dynamic/manager.py
ekagra-ranjan 409eb69
revert offline decoder to save loc diff
ekagra-ranjan d7a149f
refactor
ekagra-ranjan f222372
conflict
ekagra-ranjan 44aed5e
conflict
ekagra-ranjan 7ed3353
refactor
ekagra-ranjan 116e76b
add timeout
ekagra-ranjan d70bd1d
fix
ekagra-ranjan 24304d5
reduce loc in favor of #34105
ekagra-ranjan 8fb86b1
remove test from dynamic manager main()
ekagra-ranjan 11d43a5
remove comment and fix lint
ekagra-ranjan 5df4999
fix mypy
ekagra-ranjan e76ad8e
fix mypy
ekagra-ranjan c1e880b
lint
ekagra-ranjan 72d3c6f
lint
ekagra-ranjan 63c7e17
conflict
ekagra-ranjan 534d2f2
add AL computation to generate_config
ekagra-ranjan 3f7196e
fix padding for async sched
ekagra-ranjan 64fab8d
Update vllm/config/speculative.py
ekagra-ranjan 1198aa7
fix docstring
ekagra-ranjan c07afd1
lint
ekagra-ranjan 594dc0f
make DSD compat with async and padded drafter
ekagra-ranjan 7060a8b
Merge branch 'main' of https://github.com/vllm-project/vllm into er-d…
ekagra-ranjan cd19750
lint
ekagra-ranjan 4307feb
optimize DSD async scheduling by minimizing delay in propagating opti…
ekagra-ranjan 018b4bd
dsd config path field
ekagra-ranjan 36f5a36
refactor to simplify propose signature and update test
ekagra-ranjan c54ac4c
conflict
ekagra-ranjan 046c39a
fix comma
ekagra-ranjan b59662d
Merge branch 'main' of https://github.com/vllm-project/vllm into er-d…
ekagra-ranjan 917e3de
move towards DSD scheduler
ekagra-ranjan 3e94ad0
move towards DSD scheduler
ekagra-ranjan 8aa39fb
fix padded drafter
ekagra-ranjan 2724097
lint
ekagra-ranjan e57e624
lint
ekagra-ranjan d1601b3
conflict
ekagra-ranjan abb3a0b
works with offline profiler
ekagra-ranjan f7aa05c
remove offline profiler and use K manually
ekagra-ranjan 51d6eaf
conflict
ekagra-ranjan 8cdb086
lint
ekagra-ranjan 8331255
lint and pytest assert
ekagra-ranjan a1e2185
remove config abstraction
ekagra-ranjan 6954f68
revert noop
ekagra-ranjan 3c7aef8
remove DSD manager and update DynamicSDSchedule data formet
ekagra-ranjan 9fa1e9d
Merge branch 'main' of https://github.com/vllm-project/vllm into er-d…
ekagra-ranjan 815fd9c
lint
ekagra-ranjan 503935f
override DSD to piecewise
ekagra-ranjan b35f7c2
Merge branch 'main' of https://github.com/vllm-project/vllm into er-d…
ekagra-ranjan 5484d2c
disable mrv2 for dsd
ekagra-ranjan 6fa1d52
add doc
ekagra-ranjan 4b9ae13
add doc
ekagra-ranjan 417ddc4
add doc
ekagra-ranjan 065e846
conflict
ekagra-ranjan 5ea4979
@benchislett fix typo in doc
benchislett f4500b9
Merge branch 'main' into er-dynami-sd
ekagra-ranjan 3ce1639
Merge branch 'main' into er-dynami-sd
ekagra-ranjan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,234 +1,4 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| from transformers import AutoTokenizer | ||
|
|
||
| from vllm import LLM, SamplingParams | ||
| from vllm.benchmarks.datasets import add_dataset_parser, get_samples | ||
| from vllm.inputs import TokensPrompt | ||
| from vllm.v1.metrics.reader import Counter, Vector | ||
|
|
||
| try: | ||
| from vllm.utils.argparse_utils import FlexibleArgumentParser | ||
| except ImportError: | ||
| from argparse import ArgumentParser as FlexibleArgumentParser | ||
|
|
||
|
|
||
| QUESTION = "What is the content of each image?" | ||
| IMAGE_URLS = [ | ||
| "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg", | ||
| "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg", | ||
| "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg", | ||
| "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg", | ||
| "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg", | ||
| "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg", | ||
| "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg", | ||
| "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg", | ||
| "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg", | ||
| "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg", | ||
| "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg", | ||
| "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg", | ||
| ] | ||
|
|
||
|
|
||
| def get_custom_mm_prompts(num_prompts): | ||
| prompts = [] | ||
| for url in IMAGE_URLS: | ||
| prompts.append( | ||
| [ | ||
| {"type": "image_url", "image_url": {"url": url}}, | ||
| {"type": "text", "text": QUESTION}, | ||
| ] | ||
| ) | ||
| if num_prompts > len(IMAGE_URLS): | ||
| prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1) | ||
|
|
||
| return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]] | ||
|
|
||
|
|
||
| def parse_args(): | ||
| parser = FlexibleArgumentParser() | ||
| add_dataset_parser(parser) | ||
| parser.add_argument("--test", action="store_true") | ||
| parser.add_argument( | ||
| "--method", | ||
| type=str, | ||
| default="eagle", | ||
| choices=["ngram", "eagle", "eagle3", "mtp"], | ||
| ) | ||
| parser.add_argument("--num-spec-tokens", type=int, default=2) | ||
| parser.add_argument("--prompt-lookup-max", type=int, default=5) | ||
| parser.add_argument("--prompt-lookup-min", type=int, default=2) | ||
| parser.add_argument("--tp", type=int, default=1) | ||
| parser.add_argument("--enforce-eager", action="store_true") | ||
| parser.add_argument("--enable-chunked-prefill", action="store_true") | ||
| parser.add_argument("--max-model-len", type=int, default=16384) | ||
| parser.add_argument("--temp", type=float, default=0) | ||
| parser.add_argument("--top-p", type=float, default=1.0) | ||
| parser.add_argument("--top-k", type=int, default=-1) | ||
| parser.add_argument("--print-output", action="store_true") | ||
| parser.add_argument("--output-len", type=int, default=256) | ||
| parser.add_argument("--model-dir", type=str, default=None) | ||
| parser.add_argument("--eagle-dir", type=str, default=None) | ||
| parser.add_argument("--custom-mm-prompts", action="store_true") | ||
| return parser.parse_args() | ||
|
|
||
|
|
||
| def main(args): | ||
| args.endpoint_type = "openai-chat" | ||
|
|
||
| model_dir = args.model_dir | ||
| if args.model_dir is None: | ||
| if args.custom_mm_prompts: | ||
| raise ValueError( | ||
| "custom_mm_prompts requires mm based models" | ||
| "default llama3.1-8b-instruct is not mm based" | ||
| "please specify model_dir to give a mm based model" | ||
| ) | ||
| model_dir = "meta-llama/Llama-3.1-8B-Instruct" | ||
| tokenizer = AutoTokenizer.from_pretrained(model_dir) | ||
| args.custom_skip_chat_template = True | ||
|
|
||
| if not args.custom_mm_prompts: | ||
| prompts = get_samples(args, tokenizer) | ||
| # add_special_tokens is False to avoid adding bos twice | ||
| # when using chat templates | ||
| prompt_ids = [ | ||
| tokenizer.encode(prompt.prompt, add_special_tokens=False) | ||
| for prompt in prompts | ||
| ] | ||
| else: | ||
| prompts = get_custom_mm_prompts(args.num_prompts) | ||
|
|
||
| if args.method == "eagle" or args.method == "eagle3": | ||
| eagle_dir = args.eagle_dir | ||
| if args.method == "eagle" and eagle_dir is None: | ||
| eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" | ||
|
|
||
| elif args.method == "eagle3" and eagle_dir is None: | ||
| eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" | ||
| speculative_config = { | ||
| "method": args.method, | ||
| "model": eagle_dir, | ||
| "num_speculative_tokens": args.num_spec_tokens, | ||
| } | ||
| elif args.method == "ngram": | ||
| speculative_config = { | ||
| "method": "ngram", | ||
| "num_speculative_tokens": args.num_spec_tokens, | ||
| "prompt_lookup_max": args.prompt_lookup_max, | ||
| "prompt_lookup_min": args.prompt_lookup_min, | ||
| } | ||
| elif args.method == "mtp": | ||
| speculative_config = { | ||
| "method": "mtp", | ||
| "num_speculative_tokens": args.num_spec_tokens, | ||
| } | ||
| else: | ||
| raise ValueError(f"unknown method: {args.method}") | ||
|
|
||
| llm = LLM( | ||
| model=model_dir, | ||
| trust_remote_code=True, | ||
| tensor_parallel_size=args.tp, | ||
| enable_chunked_prefill=args.enable_chunked_prefill, | ||
| enforce_eager=args.enforce_eager, | ||
| gpu_memory_utilization=0.9, | ||
| speculative_config=speculative_config, | ||
| disable_log_stats=False, | ||
| max_model_len=args.max_model_len, | ||
| limit_mm_per_prompt={"image": 5}, | ||
| disable_chunked_mm_input=True, | ||
| ) | ||
|
|
||
| sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) | ||
| if not args.custom_mm_prompts: | ||
| outputs = llm.generate( | ||
| [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], | ||
| sampling_params=sampling_params, | ||
| ) | ||
| else: | ||
| outputs = llm.chat(prompts, sampling_params=sampling_params) | ||
|
|
||
| # print the generated text | ||
| if args.print_output: | ||
| for output in outputs: | ||
| print("-" * 50) | ||
| print(f"prompt: {output.prompt}") | ||
| print(f"generated text: {output.outputs[0].text}") | ||
| print("-" * 50) | ||
|
|
||
| metrics = llm.get_metrics() | ||
|
|
||
| total_num_output_tokens = sum( | ||
| len(output.outputs[0].token_ids) for output in outputs | ||
| ) | ||
| num_drafts = 0 | ||
| num_draft_tokens = 0 | ||
| num_accepted_tokens = 0 | ||
| acceptance_counts = [0] * args.num_spec_tokens | ||
| for metric in metrics: | ||
| if metric.name == "vllm:spec_decode_num_drafts": | ||
| assert isinstance(metric, Counter) | ||
| num_drafts += metric.value | ||
| elif metric.name == "vllm:spec_decode_num_draft_tokens": | ||
| assert isinstance(metric, Counter) | ||
| num_draft_tokens += metric.value | ||
| elif metric.name == "vllm:spec_decode_num_accepted_tokens": | ||
| assert isinstance(metric, Counter) | ||
| num_accepted_tokens += metric.value | ||
| elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": | ||
| assert isinstance(metric, Vector) | ||
| for pos in range(len(metric.values)): | ||
| acceptance_counts[pos] += metric.values[pos] | ||
|
|
||
| print("-" * 50) | ||
| print(f"total_num_output_tokens: {total_num_output_tokens}") | ||
| print(f"num_drafts: {num_drafts}") | ||
| print(f"num_draft_tokens: {num_draft_tokens}") | ||
| print(f"num_accepted_tokens: {num_accepted_tokens}") | ||
| acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 | ||
| print(f"mean acceptance length: {acceptance_length:.2f}") | ||
| print("-" * 50) | ||
|
|
||
| # print acceptance at each token position | ||
| for i in range(len(acceptance_counts)): | ||
| acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 | ||
| print(f"acceptance at token {i}: {acceptance_rate:.2f}") | ||
|
|
||
| return acceptance_length | ||
|
|
||
| from vllm.v1.spec_decode.offline import entrypoint as spec_decode_main | ||
|
|
||
| if __name__ == "__main__": | ||
| args = parse_args() | ||
| acceptance_length = main(args) | ||
|
|
||
| if args.test: | ||
| # takes ~30s to run on 1xH100 | ||
| assert args.method in ["eagle", "eagle3"] | ||
| assert args.tp == 1 | ||
| assert args.num_spec_tokens == 3 | ||
| assert args.dataset_name == "hf" | ||
| assert args.dataset_path == "philschmid/mt-bench" | ||
| assert args.num_prompts == 80 | ||
| assert args.temp == 0 | ||
| assert args.top_p == 1.0 | ||
| assert args.top_k == -1 | ||
| assert args.enable_chunked_prefill | ||
|
|
||
| # check acceptance length is within 2% of expected value | ||
| rtol = 0.02 | ||
| expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811 | ||
|
|
||
| assert ( | ||
| acceptance_length <= (1 + rtol) * expected_acceptance_length | ||
| and acceptance_length >= (1 - rtol) * expected_acceptance_length | ||
| ), ( | ||
| f"acceptance_length {acceptance_length} is not " | ||
| f"within {rtol * 100}% of {expected_acceptance_length}" | ||
| ) | ||
|
|
||
| print( | ||
| f"Test passed! Expected AL: " | ||
| f"{expected_acceptance_length}, got {acceptance_length}" | ||
| ) | ||
| spec_decode_main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.